diff --git a/.github/upgrades/prompts/SemanticKernelToAgentFramework.md b/.github/upgrades/prompts/SemanticKernelToAgentFramework.md index a8c3dcb0a6..44985bba98 100644 --- a/.github/upgrades/prompts/SemanticKernelToAgentFramework.md +++ b/.github/upgrades/prompts/SemanticKernelToAgentFramework.md @@ -839,7 +839,7 @@ var agentOptions = new ChatClientAgentRunOptions(new ChatOptions { MaxOutputTokens = 8000, // Breaking glass to access provider-specific options - RawRepresentationFactory = (_) => new OpenAI.Responses.ResponseCreationOptions() + RawRepresentationFactory = (_) => new OpenAI.Responses.CreateResponseOptions() { ReasoningOptions = new() { diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 30be1b0e8f..21d3aa2ed0 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -32,7 +32,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: persist-credentials: false diff --git a/.github/workflows/dotnet-build-and-test.yml b/.github/workflows/dotnet-build-and-test.yml index 5b44073e3f..692f3e7c45 100644 --- a/.github/workflows/dotnet-build-and-test.yml +++ b/.github/workflows/dotnet-build-and-test.yml @@ -35,19 +35,25 @@ jobs: contents: read pull-requests: read outputs: - dotnetChanges: ${{ steps.filter.outputs.dotnet}} + dotnetChanges: ${{ steps.filter.outputs.dotnet }} + cosmosDbChanges: ${{ steps.filter.outputs.cosmosdb }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: dorny/paths-filter@v3 id: filter with: filters: | dotnet: - 'dotnet/**' + cosmosdb: + - 'dotnet/src/Microsoft.Agents.AI.CosmosNoSql/**' # run only if 'dotnet' files were changed - name: dotnet tests if: steps.filter.outputs.dotnet == 'true' run: echo "Dotnet file" + - name: dotnet CosmosDB tests + if: steps.filter.outputs.cosmosdb == 'true' + run: echo "Dotnet CosmosDB changes" # run only if not 'dotnet' files were changed - name: not dotnet tests if: steps.filter.outputs.dotnet != 'true' @@ -68,7 +74,7 @@ jobs: runs-on: ${{ matrix.os }} environment: ${{ matrix.environment }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: persist-credentials: false sparse-checkout: | @@ -77,6 +83,16 @@ jobs: dotnet python workflow-samples + + # Start Cosmos DB Emulator for all integration tests and only for unit tests when CosmosDB changes happened) + - name: Start Azure Cosmos DB Emulator + if: ${{ runner.os == 'Windows' && (needs.paths-filter.outputs.cosmosDbChanges == 'true' || (github.event_name != 'pull_request' && matrix.integration-tests)) }} + shell: pwsh + run: | + Write-Host "Launching Azure Cosmos DB Emulator" + Import-Module "$env:ProgramFiles\Azure Cosmos DB Emulator\PSModules\Microsoft.Azure.CosmosDB.Emulator" + Start-CosmosDbEmulator -NoUI -Key "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + echo "COSMOS_EMULATOR_AVAILABLE=true" >> $env:GITHUB_ENV - name: Setup dotnet uses: actions/setup-dotnet@v5.0.1 @@ -123,17 +139,7 @@ jobs: popd popd rm -rf "$TEMP_DIR" - - # Start Cosmos DB Emulator for Cosmos-based unit tests (only on Windows) - - name: Start Azure Cosmos DB Emulator - if: runner.os == 'Windows' - shell: pwsh - run: | - Write-Host "Launching Azure Cosmos DB Emulator" - Import-Module "$env:ProgramFiles\Azure Cosmos DB Emulator\PSModules\Microsoft.Azure.CosmosDB.Emulator" - Start-CosmosDbEmulator -NoUI -Key "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" - echo "COSMOS_EMULATOR_AVAILABLE=true" >> $env:GITHUB_ENV - + - name: Run Unit Tests shell: bash run: | @@ -225,7 +231,7 @@ jobs: - name: Upload coverage report artifact if: matrix.targetFramework == env.COVERAGE_FRAMEWORK - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: CoverageReport-${{ matrix.os }}-${{ matrix.targetFramework }}-${{ matrix.configuration }} # Artifact name path: ./TestResults/Reports # Directory containing files to upload diff --git a/.github/workflows/dotnet-format.yml b/.github/workflows/dotnet-format.yml index 757d877028..8d7c9febb7 100644 --- a/.github/workflows/dotnet-format.yml +++ b/.github/workflows/dotnet-format.yml @@ -30,7 +30,7 @@ jobs: steps: - name: Check out code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 persist-credentials: false diff --git a/.github/workflows/markdown-link-check.yml b/.github/workflows/markdown-link-check.yml index 3b015fc6af..5c984c5796 100644 --- a/.github/workflows/markdown-link-check.yml +++ b/.github/workflows/markdown-link-check.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-22.04 # check out the latest version of the code steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: persist-credentials: false diff --git a/.github/workflows/python-code-quality.yml b/.github/workflows/python-code-quality.yml index dd4c0b57cf..4139d47156 100644 --- a/.github/workflows/python-code-quality.yml +++ b/.github/workflows/python-code-quality.yml @@ -27,7 +27,7 @@ jobs: env: UV_PYTHON: ${{ matrix.python-version }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up python and install the project @@ -39,7 +39,7 @@ jobs: env: # Configure a constant location for the uv cache UV_CACHE_DIR: /tmp/.uv-cache - - uses: actions/cache@v4 + - uses: actions/cache@v5 with: path: ~/.cache/pre-commit key: pre-commit|${{ matrix.python-version }}|${{ hashFiles('python/.pre-commit-config.yaml') }} diff --git a/.github/workflows/python-docs.yml b/.github/workflows/python-docs.yml index b2be4b6ad0..f962ec318f 100644 --- a/.github/workflows/python-docs.yml +++ b/.github/workflows/python-docs.yml @@ -24,7 +24,7 @@ jobs: run: working-directory: python steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up uv uses: astral-sh/setup-uv@v7 with: diff --git a/.github/workflows/python-lab-tests.yml b/.github/workflows/python-lab-tests.yml index ae526cf962..f5cb504d04 100644 --- a/.github/workflows/python-lab-tests.yml +++ b/.github/workflows/python-lab-tests.yml @@ -24,7 +24,7 @@ jobs: outputs: pythonChanges: ${{ steps.filter.outputs.python}} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: dorny/paths-filter@v3 id: filter with: @@ -59,7 +59,7 @@ jobs: run: working-directory: python steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up python and install the project id: python-setup diff --git a/.github/workflows/python-merge-tests.yml b/.github/workflows/python-merge-tests.yml index a30b3c4ac3..66b9122726 100644 --- a/.github/workflows/python-merge-tests.yml +++ b/.github/workflows/python-merge-tests.yml @@ -28,7 +28,7 @@ jobs: outputs: pythonChanges: ${{ steps.filter.outputs.python}} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: dorny/paths-filter@v3 id: filter with: @@ -75,7 +75,7 @@ jobs: run: working-directory: python steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up python and install the project id: python-setup uses: ./.github/actions/python-setup @@ -135,7 +135,7 @@ jobs: run: working-directory: python steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up python and install the project id: python-setup uses: ./.github/actions/python-setup diff --git a/.github/workflows/python-release.yml b/.github/workflows/python-release.yml index 97f1ef2481..ba6e3689b0 100644 --- a/.github/workflows/python-release.yml +++ b/.github/workflows/python-release.yml @@ -23,7 +23,7 @@ jobs: run: working-directory: python steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up python and install the project id: python-setup uses: ./.github/actions/python-setup diff --git a/.github/workflows/python-test-coverage-report.yml b/.github/workflows/python-test-coverage-report.yml index 3da82cc943..e09d9c8870 100644 --- a/.github/workflows/python-test-coverage-report.yml +++ b/.github/workflows/python-test-coverage-report.yml @@ -19,9 +19,9 @@ jobs: run: working-directory: python steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Download coverage report - uses: actions/download-artifact@v6 + uses: actions/download-artifact@v7 with: github-token: ${{ secrets.GH_ACTIONS_PR_WRITE }} run-id: ${{ github.event.workflow_run.id }} diff --git a/.github/workflows/python-test-coverage.yml b/.github/workflows/python-test-coverage.yml index dd260ba5f6..03cca20e06 100644 --- a/.github/workflows/python-test-coverage.yml +++ b/.github/workflows/python-test-coverage.yml @@ -20,7 +20,7 @@ jobs: env: UV_PYTHON: "3.10" steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 # Save the PR number to a file since the workflow_run event # in the coverage report workflow does not have access to it - name: Save PR number @@ -38,7 +38,7 @@ jobs: - name: Run all tests with coverage report run: uv run poe all-tests-cov --cov-report=xml:python-coverage.xml -q --junitxml=pytest.xml - name: Upload coverage report - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: path: | python/python-coverage.xml diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 697a8ff4a7..07b9200a46 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -27,7 +27,7 @@ jobs: run: working-directory: python steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up python and install the project id: python-setup uses: ./.github/actions/python-setup diff --git a/agent-samples/openai/OpenAIResponses.yaml b/agent-samples/openai/OpenAIResponses.yaml index 45f4c30340..bdc04d4a13 100644 --- a/agent-samples/openai/OpenAIResponses.yaml +++ b/agent-samples/openai/OpenAIResponses.yaml @@ -10,19 +10,19 @@ model: temperature: 0.9 topP: 0.95 connection: - kind: ApiKey - key: =Env.OPENAI_API_KEY + kind: key + apiKey: =Env.OPENAI_APIKEY outputSchema: properties: language: - type: string + kind: string required: true description: The language of the answer. answer: - type: string + kind: string required: true description: The answer text. type: - type: string + kind: string required: true description: The type of the response. diff --git a/docs/features/durable-agents/durable-agents-ttl.md b/docs/features/durable-agents/durable-agents-ttl.md new file mode 100644 index 0000000000..1a4a4e32d6 --- /dev/null +++ b/docs/features/durable-agents/durable-agents-ttl.md @@ -0,0 +1,147 @@ +# Time-To-Live (TTL) for durable agent sessions + +## Overview + +The durable agents automatically maintain conversation history and state for each session. Without automatic cleanup, this state can accumulate indefinitely, consuming storage resources and increasing costs. The Time-To-Live (TTL) feature provides automatic cleanup of idle agent sessions, ensuring that sessions are automatically deleted after a period of inactivity. + +## What is TTL? + +Time-To-Live (TTL) is a configurable duration that determines how long an agent session state will be retained after its last interaction. When an agent session is idle (no messages sent to it) for longer than the TTL period, the session state is automatically deleted. Each new interaction with an agent resets the TTL timer, extending the session's lifetime. + +## Benefits + +- **Automatic cleanup**: No manual intervention required to clean up idle agent sessions +- **Cost optimization**: Reduces storage costs by automatically removing unused session state +- **Resource management**: Prevents unbounded growth of agent session state in storage +- **Configurable**: Set TTL globally or per-agent type to match your application's needs + +## Configuration + +TTL can be configured at two levels: + +1. **Global default TTL**: Applies to all agent sessions unless overridden +2. **Per-agent type TTL**: Overrides the global default for specific agent types + +Additionally, you can configure a **minimum deletion delay** that controls how frequently deletion operations are scheduled. The default value is 5 minutes, and the maximum allowed value is also 5 minutes. + +> [!NOTE] +> Reducing the minimum deletion delay below 5 minutes can be useful for testing or for ensuring rapid cleanup of short-lived agent sessions. However, this can also increase the load on the system and should be used with caution. + +### Default values + +- **Default TTL**: 14 days +- **Minimum TTL deletion delay**: 5 minutes (maximum allowed value, subject to change in future releases) + +### Configuration examples + +#### .NET + +```csharp +// Configure global default TTL and minimum signal delay +services.ConfigureDurableAgents( + options => + { + // Set global default TTL to 7 days + options.DefaultTimeToLive = TimeSpan.FromDays(7); + + // Add agents (will use global default TTL) + options.AddAIAgent(myAgent); + }); + +// Configure per-agent TTL +services.ConfigureDurableAgents( + options => + { + options.DefaultTimeToLive = TimeSpan.FromDays(14); // Global default + + // Agent with custom TTL of 1 day + options.AddAIAgent(shortLivedAgent, timeToLive: TimeSpan.FromDays(1)); + + // Agent with custom TTL of 90 days + options.AddAIAgent(longLivedAgent, timeToLive: TimeSpan.FromDays(90)); + + // Agent using global default (14 days) + options.AddAIAgent(defaultAgent); + }); + +// Disable TTL for specific agents by setting TTL to null +services.ConfigureDurableAgents( + options => + { + options.DefaultTimeToLive = TimeSpan.FromDays(14); + + // Agent with no TTL (never expires) + options.AddAIAgent(permanentAgent, timeToLive: null); + }); +``` + +## How TTL works + +The following sections describe how TTL works in detail. + +### Expiration tracking + +Each agent session maintains an expiration timestamp in its internally managed state that is updated whenever the session processes a message: + +1. When a message is sent to an agent session, the expiration time is set to `current time + TTL` +2. The runtime schedules a delete operation for the expiration time (subject to minimum delay constraints) +3. When the delete operation runs, if the current time is past the expiration time, the session state is deleted. Otherwise, the delete operation is rescheduled for the next expiration time. + +### State deletion + +When an agent session expires, its entire state is deleted, including: + +- Conversation history +- Any custom state data +- Expiration timestamps + +After deletion, if a message is sent to the same agent session, a new session is created with a fresh conversation history. + +## Behavior examples + +The following examples illustrate how TTL works in different scenarios. + +### Example 1: Agent session expires after TTL + +1. Agent configured with 30-day TTL +2. User sends message at Day 0 → agent session created, expiration set to Day 30 +3. No further messages sent +4. At Day 30 → Agent session is deleted +5. User sends message at Day 31 → New agent session created with fresh conversation history + +### Example 2: TTL reset on interaction + +1. Agent configured with 30-day TTL +2. User sends message at Day 0 → agent session created, expiration set to Day 30 +3. User sends message at Day 15 → Expiration reset to Day 45 +4. User sends message at Day 40 → Expiration reset to Day 70 +5. Agent session remains active as long as there are regular interactions + +## Logging + +The TTL feature includes comprehensive logging to track state changes: + +- **Expiration time updated**: Logged when TTL expiration time is set or updated +- **Deletion scheduled**: Logged when a deletion check signal is scheduled +- **Deletion check**: Logged when a deletion check operation runs +- **Session expired**: Logged when an agent session is deleted due to expiration +- **TTL rescheduled**: Logged when a deletion signal is rescheduled + +These logs help monitor TTL behavior and troubleshoot any issues. + +## Best practices + +1. **Choose appropriate TTL values**: Balance between storage costs and user experience. Too short TTLs may delete active sessions, while too long TTLs may accumulate unnecessary state. + +2. **Use per-agent TTLs**: Different agents may have different usage patterns. Configure TTLs per-agent based on expected session lifetimes. + +3. **Monitor expiration logs**: Review logs to understand TTL behavior and adjust configuration as needed. + +4. **Test with short TTLs**: During development, use short TTLs (e.g., minutes) to verify TTL behavior without waiting for long periods. + +## Limitations + +- TTL is based on wall-clock time, not activity time. The expiration timer starts from the last message timestamp. +- Deletion checks are durably scheduled operations and may have slight delays depending on system load. +- Once an agent session is deleted, its conversation history cannot be recovered. +- TTL deletion requires at least one worker to be available to process the deletion operation message. diff --git a/dotnet/.editorconfig b/dotnet/.editorconfig index c0d0d04fe9..fea0183976 100644 --- a/dotnet/.editorconfig +++ b/dotnet/.editorconfig @@ -209,6 +209,7 @@ dotnet_diagnostic.CA2000.severity = none # Call System.IDisposable.Dispose on ob dotnet_diagnostic.CA2225.severity = none # Operator overloads have named alternates dotnet_diagnostic.CA2227.severity = none # Change to be read-only by removing the property setter dotnet_diagnostic.CA2249.severity = suggestion # Consider using 'Contains' method instead of 'IndexOf' method +dotnet_diagnostic.CA2252.severity = none # Requires preview dotnet_diagnostic.CA2253.severity = none # Named placeholders in the logging message template should not be comprised of only numeric characters dotnet_diagnostic.CA2253.severity = none # Named placeholders in the logging message template should not be comprised of only numeric characters dotnet_diagnostic.CA2263.severity = suggestion # Use generic overload diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 89be7f3520..4825a42921 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -19,11 +19,11 @@ - - + + - - + + @@ -61,10 +61,9 @@ - - - - + + + @@ -101,11 +100,10 @@ - - + - + diff --git a/dotnet/agent-framework-dotnet.slnx b/dotnet/agent-framework-dotnet.slnx index ec197f58c0..5e08a766f9 100644 --- a/dotnet/agent-framework-dotnet.slnx +++ b/dotnet/agent-framework-dotnet.slnx @@ -129,6 +129,7 @@ + diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/AgentWebChat.AgentHost.csproj b/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/AgentWebChat.AgentHost.csproj index 3f2a832a69..f71becf5d3 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/AgentWebChat.AgentHost.csproj +++ b/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/AgentWebChat.AgentHost.csproj @@ -25,7 +25,6 @@ - diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/Utilities/ChatClientExtensions.cs b/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/Utilities/ChatClientExtensions.cs index 6cd3d888c8..7b1f2d86b4 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/Utilities/ChatClientExtensions.cs +++ b/dotnet/samples/AgentWebChat/AgentWebChat.AgentHost/Utilities/ChatClientExtensions.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using AgentWebChat.AgentHost.Utilities; -using Azure; -using Azure.AI.Inference; using Microsoft.Extensions.AI; using OllamaSharp; @@ -24,7 +22,6 @@ public static ChatClientBuilder AddChatClient(this IHostApplicationBuilder build ClientChatProvider.Ollama => builder.AddOllamaClient(connectionName, connectionInfo), ClientChatProvider.OpenAI => builder.AddOpenAIClient(connectionName, connectionInfo), ClientChatProvider.AzureOpenAI => builder.AddAzureOpenAIClient(connectionName).AddChatClient(connectionInfo.SelectedModel), - ClientChatProvider.AzureAIInference => builder.AddAzureInferenceClient(connectionName, connectionInfo), _ => throw new NotSupportedException($"Unsupported provider: {connectionInfo.Provider}") }; @@ -44,16 +41,6 @@ private static ChatClientBuilder AddOpenAIClient(this IHostApplicationBuilder bu }) .AddChatClient(connectionInfo.SelectedModel); - private static ChatClientBuilder AddAzureInferenceClient(this IHostApplicationBuilder builder, string connectionName, ChatClientConnectionInfo connectionInfo) => - builder.Services.AddChatClient(sp => - { - var credential = new AzureKeyCredential(connectionInfo.AccessKey!); - - var client = new ChatCompletionsClient(connectionInfo.Endpoint, credential, new AzureAIInferenceClientOptions()); - - return client.AsIChatClient(connectionInfo.SelectedModel); - }); - private static ChatClientBuilder AddOllamaClient(this IHostApplicationBuilder builder, string connectionName, ChatClientConnectionInfo connectionInfo) { var httpKey = $"{connectionName}_http"; @@ -83,7 +70,6 @@ public static ChatClientBuilder AddKeyedChatClient(this IHostApplicationBuilder ClientChatProvider.Ollama => builder.AddKeyedOllamaClient(connectionName, connectionInfo), ClientChatProvider.OpenAI => builder.AddKeyedOpenAIClient(connectionName, connectionInfo), ClientChatProvider.AzureOpenAI => builder.AddKeyedAzureOpenAIClient(connectionName).AddKeyedChatClient(connectionName, connectionInfo.SelectedModel), - ClientChatProvider.AzureAIInference => builder.AddKeyedAzureInferenceClient(connectionName, connectionInfo), _ => throw new NotSupportedException($"Unsupported provider: {connectionInfo.Provider}") }; @@ -103,16 +89,6 @@ private static ChatClientBuilder AddKeyedOpenAIClient(this IHostApplicationBuild }) .AddKeyedChatClient(connectionName, connectionInfo.SelectedModel); - private static ChatClientBuilder AddKeyedAzureInferenceClient(this IHostApplicationBuilder builder, string connectionName, ChatClientConnectionInfo connectionInfo) => - builder.Services.AddKeyedChatClient(connectionName, sp => - { - var credential = new AzureKeyCredential(connectionInfo.AccessKey!); - - var client = new ChatCompletionsClient(connectionInfo.Endpoint, credential, new AzureAIInferenceClientOptions()); - - return client.AsIChatClient(connectionInfo.SelectedModel); - }); - private static ChatClientBuilder AddKeyedOllamaClient(this IHostApplicationBuilder builder, string connectionName, ChatClientConnectionInfo connectionInfo) { var httpKey = $"{connectionName}_http"; diff --git a/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs b/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs index 7cc85b97c3..d0121a6165 100644 --- a/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs +++ b/dotnet/samples/AgentWebChat/AgentWebChat.Web/OpenAIResponsesAgentClient.cs @@ -27,7 +27,7 @@ public override async IAsyncEnumerable RunStreamingAsync Transport = new HttpClientPipelineTransport(httpClient) }; - var openAiClient = new OpenAIResponseClient(model: agentName, credential: new ApiKeyCredential("dummy-key"), options: options).AsIChatClient(); + var openAiClient = new ResponsesClient(model: agentName, credential: new ApiKeyCredential("dummy-key"), options: options).AsIChatClient(); var chatOptions = new ChatOptions() { ConversationId = threadId diff --git a/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs b/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs index b3d40a120c..39b020e137 100644 --- a/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs +++ b/dotnet/samples/AzureFunctions/01_SingleAgent/Program.cs @@ -32,6 +32,6 @@ using IHost app = FunctionsApplication .CreateBuilder(args) .ConfigureFunctionsWebApplication() - .ConfigureDurableAgents(options => options.AddAIAgent(agent)) + .ConfigureDurableAgents(options => options.AddAIAgent(agent, timeToLive: TimeSpan.FromHours(1))) .Build(); app.Run(); diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs index 83d5619382..5ce85b2b91 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_AzureOpenAIResponses/Program.cs @@ -13,7 +13,7 @@ AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs index df53ba8869..b0d0285928 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_OpenAIResponses/Program.cs @@ -11,7 +11,7 @@ AIAgent agent = new OpenAIClient( apiKey) - .GetOpenAIResponseClient(model) + .GetResponsesClient(model) .CreateAIAgent(instructions: "You are good at telling jokes.", name: "Joker"); // Invoke the agent and output the text result. diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step02_Reasoning/Program.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step02_Reasoning/Program.cs index e06a8cc76f..aa18fdd286 100644 --- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step02_Reasoning/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step02_Reasoning/Program.cs @@ -11,11 +11,11 @@ var model = Environment.GetEnvironmentVariable("OPENAI_MODEL") ?? "gpt-5"; var client = new OpenAIClient(apiKey) - .GetOpenAIResponseClient(model) + .GetResponsesClient(model) .AsIChatClient().AsBuilder() .ConfigureOptions(o => { - o.RawRepresentationFactory = _ => new ResponseCreationOptions() + o.RawRepresentationFactory = _ => new CreateResponseOptions() { ReasoningOptions = new() { diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs index 456de02836..622223307c 100644 --- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/OpenAIResponseClientAgent.cs @@ -16,13 +16,13 @@ public class OpenAIResponseClientAgent : DelegatingAIAgent /// /// Initialize an instance of . /// - /// Instance of + /// Instance of /// Optional instructions for the agent. /// Optional name for the agent. /// Optional description for the agent. /// Optional instance of public OpenAIResponseClientAgent( - OpenAIResponseClient client, + ResponsesClient client, string? instructions = null, string? name = null, string? description = null, @@ -39,11 +39,11 @@ public OpenAIResponseClientAgent( /// /// Initialize an instance of . /// - /// Instance of + /// Instance of /// Options to create the agent. /// Optional instance of public OpenAIResponseClientAgent( - OpenAIResponseClient client, ChatClientAgentOptions options, ILoggerFactory? loggerFactory = null) : + ResponsesClient client, ChatClientAgentOptions options, ILoggerFactory? loggerFactory = null) : base(new ChatClientAgent((client ?? throw new ArgumentNullException(nameof(client))).AsIChatClient(), options, loggerFactory)) { } @@ -55,8 +55,8 @@ public OpenAIResponseClientAgent( /// The conversation thread to continue with this invocation. If not provided, creates a new thread. The thread will be mutated with the provided messages and agent response. /// Optional parameters for agent invocation. /// The to monitor for cancellation requests. The default is . - /// A containing the list of items. - public virtual async Task RunAsync( + /// A containing the list of items. + public virtual async Task RunAsync( IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, @@ -74,7 +74,7 @@ public virtual async Task RunAsync( /// The conversation thread to continue with this invocation. If not provided, creates a new thread. The thread will be mutated with the provided messages and agent response. /// Optional parameters for agent invocation. /// The to monitor for cancellation requests. The default is . - /// A containing the list of items. + /// A containing the list of items. public virtual async IAsyncEnumerable RunStreamingAsync( IEnumerable messages, AgentThread? thread = null, diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/Program.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/Program.cs index 89a96bc0fb..5c229cc57d 100644 --- a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/Program.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -// This sample demonstrates how to create OpenAIResponseClientAgent directly from an OpenAIResponseClient instance. +// This sample demonstrates how to create OpenAIResponseClientAgent directly from an ResponsesClient instance. using OpenAI; using OpenAI.Responses; @@ -9,16 +9,16 @@ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set."); var model = Environment.GetEnvironmentVariable("OPENAI_MODEL") ?? "gpt-4o-mini"; -// Create an OpenAIResponseClient directly from OpenAIClient -OpenAIResponseClient responseClient = new OpenAIClient(apiKey).GetOpenAIResponseClient(model); +// Create a ResponsesClient directly from OpenAIClient +ResponsesClient responseClient = new OpenAIClient(apiKey).GetResponsesClient(model); -// Create an agent directly from the OpenAIResponseClient using OpenAIResponseClientAgent +// Create an agent directly from the ResponsesClient using OpenAIResponseClientAgent OpenAIResponseClientAgent agent = new(responseClient, instructions: "You are good at telling jokes.", name: "Joker"); ResponseItem userMessage = ResponseItem.CreateUserMessageItem("Tell me a joke about a pirate."); // Invoke the agent and output the text result. -OpenAIResponse response = await agent.RunAsync([userMessage]); +ResponseResult response = await agent.RunAsync([userMessage]); Console.WriteLine(response.GetOutputText()); // Invoke the agent with streaming support. diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Agent_OpenAI_Step05_Conversation.csproj b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Agent_OpenAI_Step05_Conversation.csproj new file mode 100644 index 0000000000..eeda3eef6f --- /dev/null +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Agent_OpenAI_Step05_Conversation.csproj @@ -0,0 +1,15 @@ + + + + Exe + net10.0 + + enable + enable + + + + + + + diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs new file mode 100644 index 0000000000..8aebebdfa0 --- /dev/null +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/Program.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft. All rights reserved. + +// This sample demonstrates how to maintain conversation state using the OpenAIResponseClientAgent +// and AgentThread. By passing the same thread to multiple agent invocations, the agent +// automatically maintains the conversation history, allowing the AI model to understand +// context from previous exchanges. + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; +using Microsoft.Agents.AI; +using Microsoft.Extensions.AI; +using OpenAI; +using OpenAI.Chat; +using OpenAI.Conversations; + +string apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set."); +string model = Environment.GetEnvironmentVariable("OPENAI_MODEL") ?? "gpt-4o-mini"; + +// Create a ConversationClient directly from OpenAIClient +OpenAIClient openAIClient = new(apiKey); +ConversationClient conversationClient = openAIClient.GetConversationClient(); + +// Create an agent directly from the ResponsesClient using OpenAIResponseClientAgent +ChatClientAgent agent = new(openAIClient.GetResponsesClient(model).AsIChatClient(), instructions: "You are a helpful assistant.", name: "ConversationAgent"); + +ClientResult createConversationResult = await conversationClient.CreateConversationAsync(BinaryContent.Create(BinaryData.FromString("{}"))); + +using JsonDocument createConversationResultAsJson = JsonDocument.Parse(createConversationResult.GetRawResponse().Content.ToString()); +string conversationId = createConversationResultAsJson.RootElement.GetProperty("id"u8)!.GetString()!; + +// Create a thread for the conversation - this enables conversation state management for subsequent turns +AgentThread thread = agent.GetNewThread(conversationId); + +Console.WriteLine("=== Multi-turn Conversation Demo ===\n"); + +// First turn: Ask about a topic +Console.WriteLine("User: What is the capital of France?"); +UserChatMessage firstMessage = new("What is the capital of France?"); + +// After this call, the conversation state associated in the options is stored in 'thread' and used in subsequent calls +ChatCompletion firstResponse = await agent.RunAsync([firstMessage], thread); +Console.WriteLine($"Assistant: {firstResponse.Content.Last().Text}\n"); + +// Second turn: Follow-up question that relies on conversation context +Console.WriteLine("User: What famous landmarks are located there?"); +UserChatMessage secondMessage = new("What famous landmarks are located there?"); + +ChatCompletion secondResponse = await agent.RunAsync([secondMessage], thread); +Console.WriteLine($"Assistant: {secondResponse.Content.Last().Text}\n"); + +// Third turn: Another follow-up that demonstrates context continuity +Console.WriteLine("User: How tall is the most famous one?"); +UserChatMessage thirdMessage = new("How tall is the most famous one?"); + +ChatCompletion thirdResponse = await agent.RunAsync([thirdMessage], thread); +Console.WriteLine($"Assistant: {thirdResponse.Content.Last().Text}\n"); + +Console.WriteLine("=== End of Conversation ==="); + +// Show full conversation history +Console.WriteLine("Full Conversation History:"); +ClientResult getConversationResult = await conversationClient.GetConversationAsync(conversationId); + +Console.WriteLine("Conversation created."); +Console.WriteLine($" Conversation ID: {conversationId}"); +Console.WriteLine(); + +CollectionResult getConversationItemsResults = conversationClient.GetConversationItems(conversationId); +foreach (ClientResult result in getConversationItemsResults.GetRawPages()) +{ + Console.WriteLine("Message contents retrieved. Order is most recent first by default."); + using JsonDocument getConversationItemsResultAsJson = JsonDocument.Parse(result.GetRawResponse().Content.ToString()); + foreach (JsonElement element in getConversationItemsResultAsJson.RootElement.GetProperty("data").EnumerateArray()) + { + string messageId = element.GetProperty("id"u8).ToString(); + string messageRole = element.GetProperty("role"u8).ToString(); + Console.WriteLine($" Message ID: {messageId}"); + Console.WriteLine($" Message Role: {messageRole}"); + + foreach (var content in element.GetProperty("content").EnumerateArray()) + { + string messageContentText = content.GetProperty("text"u8).ToString(); + Console.WriteLine($" Message Text: {messageContentText}"); + } + Console.WriteLine(); + } +} + +ClientResult deleteConversationResult = conversationClient.DeleteConversation(conversationId); +using JsonDocument deleteConversationResultAsJson = JsonDocument.Parse(deleteConversationResult.GetRawResponse().Content.ToString()); +bool deleted = deleteConversationResultAsJson.RootElement + .GetProperty("deleted"u8) + .GetBoolean(); + +Console.WriteLine("Conversation deleted."); +Console.WriteLine($" Deleted: {deleted}"); +Console.WriteLine(); diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/README.md b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/README.md new file mode 100644 index 0000000000..c279ba2c17 --- /dev/null +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/Agent_OpenAI_Step05_Conversation/README.md @@ -0,0 +1,90 @@ +# Managing Conversation State with OpenAI + +This sample demonstrates how to maintain conversation state across multiple turns using the Agent Framework with OpenAI's Conversation API. + +## What This Sample Shows + +- **Conversation State Management**: Shows how to use `ConversationClient` and `AgentThread` to maintain conversation context across multiple agent invocations +- **Multi-turn Conversations**: Demonstrates follow-up questions that rely on context from previous messages in the conversation +- **Server-Side Storage**: Uses OpenAI's Conversation API to manage conversation history server-side, allowing the model to access previous messages without resending them +- **Conversation Lifecycle**: Demonstrates creating, retrieving, and deleting conversations + +## Key Concepts + +### ConversationClient for Server-Side Storage + +The `ConversationClient` manages conversations on OpenAI's servers: + +```csharp +// Create a ConversationClient from OpenAIClient +OpenAIClient openAIClient = new(apiKey); +ConversationClient conversationClient = openAIClient.GetConversationClient(); + +// Create a new conversation +ClientResult createConversationResult = await conversationClient.CreateConversationAsync(BinaryContent.Create(BinaryData.FromString("{}"))); +``` + +### AgentThread for Conversation State + +The `AgentThread` works with `ChatClientAgentRunOptions` to link the agent to a server-side conversation: + +```csharp +// Set up agent run options with the conversation ID +ChatClientAgentRunOptions agentRunOptions = new() { ChatOptions = new ChatOptions() { ConversationId = conversationId } }; + +// Create a thread for the conversation +AgentThread thread = agent.GetNewThread(); + +// First call links the thread to the conversation +ChatCompletion firstResponse = await agent.RunAsync([firstMessage], thread, agentRunOptions); + +// Subsequent calls use the thread without needing to pass options again +ChatCompletion secondResponse = await agent.RunAsync([secondMessage], thread); +``` + +### Retrieving Conversation History + +You can retrieve the full conversation history from the server: + +```csharp +CollectionResult getConversationItemsResults = conversationClient.GetConversationItems(conversationId); +foreach (ClientResult result in getConversationItemsResults.GetRawPages()) +{ + // Process conversation items +} +``` + +### How It Works + +1. **Create an OpenAI Client**: Initialize an `OpenAIClient` with your API key +2. **Create a Conversation**: Use `ConversationClient` to create a server-side conversation +3. **Create an Agent**: Initialize an `OpenAIResponseClientAgent` with the desired model and instructions +4. **Create a Thread**: Call `agent.GetNewThread()` to create a new conversation thread +5. **Link Thread to Conversation**: Pass `ChatClientAgentRunOptions` with the `ConversationId` on the first call +6. **Send Messages**: Subsequent calls to `agent.RunAsync()` only need the thread - context is maintained +7. **Cleanup**: Delete the conversation when done using `conversationClient.DeleteConversation()` + +## Running the Sample + +1. Set the required environment variables: + ```powershell + $env:OPENAI_API_KEY = "your_api_key_here" + $env:OPENAI_MODEL = "gpt-4o-mini" + ``` + +2. Run the sample: + ```powershell + dotnet run + ``` + +## Expected Output + +The sample demonstrates a three-turn conversation where each follow-up question relies on context from previous messages: + +1. First question asks about the capital of France +2. Second question asks about landmarks "there" - requiring understanding of the previous answer +3. Third question asks about "the most famous one" - requiring context from both previous turns + +After the conversation, the sample retrieves and displays the full conversation history from the server, then cleans up by deleting the conversation. + +This demonstrates that the conversation state is properly maintained across multiple agent invocations using OpenAI's server-side conversation storage. diff --git a/dotnet/samples/GettingStarted/AgentWithOpenAI/README.md b/dotnet/samples/GettingStarted/AgentWithOpenAI/README.md index 6f2c77f39b..019af7f2b6 100644 --- a/dotnet/samples/GettingStarted/AgentWithOpenAI/README.md +++ b/dotnet/samples/GettingStarted/AgentWithOpenAI/README.md @@ -13,4 +13,5 @@ Agent Framework provides additional support to allow OpenAI developers to use th |[Creating an AIAgent](./Agent_OpenAI_Step01_Running/)|This sample demonstrates how to create and run a basic agent with native OpenAI SDK types. Shows both regular and streaming invocation of the agent.| |[Using Reasoning Capabilities](./Agent_OpenAI_Step02_Reasoning/)|This sample demonstrates how to create an AI agent with reasoning capabilities using OpenAI's reasoning models and response types.| |[Creating an Agent from a ChatClient](./Agent_OpenAI_Step03_CreateFromChatClient/)|This sample demonstrates how to create an AI agent directly from an OpenAI.Chat.ChatClient instance using OpenAIChatClientAgent.| -|[Creating an Agent from an OpenAIResponseClient](./Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/)|This sample demonstrates how to create an AI agent directly from an OpenAI.Responses.OpenAIResponseClient instance using OpenAIResponseClientAgent.| \ No newline at end of file +|[Creating an Agent from an OpenAIResponseClient](./Agent_OpenAI_Step04_CreateFromOpenAIResponseClient/)|This sample demonstrates how to create an AI agent directly from an OpenAI.Responses.OpenAIResponseClient instance using OpenAIResponseClientAgent.| +|[Managing Conversation State](./Agent_OpenAI_Step05_Conversation/)|This sample demonstrates how to maintain conversation state across multiple turns using the AgentThread for context continuity.| \ No newline at end of file diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs index 41493f6d79..29dc347b4a 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step13_BackgroundResponsesWithToolsAndPersistence/Program.cs @@ -22,7 +22,7 @@ AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent( name: "SpaceNovelWriter", instructions: "You are a space novel writer. Always research relevant facts and generate character profiles for the main characters before writing novels." + diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs index 510a5dfbd0..3e172a95b5 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step17_BackgroundResponses/Program.cs @@ -13,7 +13,7 @@ AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent(); // Enable background responses (only supported by OpenAI Responses at this time). diff --git a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs index 05fb39bbf4..ff4f57924a 100644 --- a/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs +++ b/dotnet/samples/GettingStarted/FoundryAgents/FoundryAgents_Step15_ComputerUse/Program.cs @@ -73,7 +73,7 @@ private static async Task InvokeComputerUseAgentAsync(AIAgent agent) Dictionary screenshots = ComputerUseUtil.LoadScreenshotAssets(); ChatOptions chatOptions = new(); - ResponseCreationOptions responseCreationOptions = new() + CreateResponseOptions responseCreationOptions = new() { TruncationMode = ResponseTruncationMode.Auto }; diff --git a/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs b/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs index ba4249c765..13ee28d6a1 100644 --- a/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs +++ b/dotnet/samples/GettingStarted/ModelContextProtocol/ResponseAgent_Hosted_MCP/Program.cs @@ -30,7 +30,7 @@ AIAgent agent = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent( instructions: "You answer questions by searching the Microsoft Learn content only.", name: "MicrosoftLearnAgent", @@ -57,7 +57,7 @@ AIAgent agentWithRequiredApproval = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .CreateAIAgent( instructions: "You answer questions by searching the Microsoft Learn content only.", name: "MicrosoftLearnAgentWithApproval", diff --git a/dotnet/samples/HostedAgents/AgentWithHostedMCP/AgentWithHostedMCP.csproj b/dotnet/samples/HostedAgents/AgentWithHostedMCP/AgentWithHostedMCP.csproj index 21146dd2dd..d2c0ea70f8 100644 --- a/dotnet/samples/HostedAgents/AgentWithHostedMCP/AgentWithHostedMCP.csproj +++ b/dotnet/samples/HostedAgents/AgentWithHostedMCP/AgentWithHostedMCP.csproj @@ -36,9 +36,9 @@ - + - + diff --git a/dotnet/samples/HostedAgents/AgentWithTextSearchRag/AgentWithTextSearchRag.csproj b/dotnet/samples/HostedAgents/AgentWithTextSearchRag/AgentWithTextSearchRag.csproj index 118b92e074..e67846f54c 100644 --- a/dotnet/samples/HostedAgents/AgentWithTextSearchRag/AgentWithTextSearchRag.csproj +++ b/dotnet/samples/HostedAgents/AgentWithTextSearchRag/AgentWithTextSearchRag.csproj @@ -37,7 +37,7 @@ - + diff --git a/dotnet/samples/HostedAgents/AgentsInWorkflows/AgentsInWorkflows.csproj b/dotnet/samples/HostedAgents/AgentsInWorkflows/AgentsInWorkflows.csproj index c0fca8a340..a865f43be5 100644 --- a/dotnet/samples/HostedAgents/AgentsInWorkflows/AgentsInWorkflows.csproj +++ b/dotnet/samples/HostedAgents/AgentsInWorkflows/AgentsInWorkflows.csproj @@ -37,7 +37,7 @@ - + diff --git a/dotnet/samples/Purview/AgentWithPurview/Program.cs b/dotnet/samples/Purview/AgentWithPurview/Program.cs index 842917b427..a4b27c47cd 100644 --- a/dotnet/samples/Purview/AgentWithPurview/Program.cs +++ b/dotnet/samples/Purview/AgentWithPurview/Program.cs @@ -27,7 +27,7 @@ using IChatClient client = new AzureOpenAIClient( new Uri(endpoint), new AzureCliCredential()) - .GetOpenAIResponseClient(deploymentName) + .GetResponsesClient(deploymentName) .AsIChatClient() .AsBuilder() .WithPurview(browserCredential, new PurviewSettings("Agent Framework Test App")) diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index bdb1f72928..e804fbb389 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -198,7 +198,7 @@ public override async IAsyncEnumerable RunStreamingAsync } /// - public override string Id => this._id ?? base.Id; + protected override string? IdCore => this._id; /// public override string? Name => this._name ?? base.Name; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs index eba6f84687..4cff385dcc 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs @@ -22,9 +22,6 @@ namespace Microsoft.Agents.AI; [DebuggerDisplay("{DisplayName,nq}")] public abstract class AIAgent { - /// Default ID of this agent instance. - private readonly string _id = Guid.NewGuid().ToString("N"); - /// /// Gets the unique identifier for this agent instance. /// @@ -37,7 +34,19 @@ public abstract class AIAgent /// agent instances in multi-agent scenarios. They should remain stable for the lifetime /// of the agent instance. /// - public virtual string Id => this._id; + public string Id { get => this.IdCore ?? field; } = Guid.NewGuid().ToString("N"); + + /// + /// Gets a custom identifier for the agent, which can be overridden by derived classes. + /// + /// + /// A string representing the agent's identifier, or if the default ID should be used. + /// + /// + /// Derived classes can override this property to provide a custom identifier. + /// When is returned, the property will use the default randomly-generated identifier. + /// + protected virtual string? IdCore => null; /// /// Gets the human-readable name of the agent. @@ -61,7 +70,7 @@ public abstract class AIAgent /// This property provides a guaranteed non-null string suitable for display in user interfaces, /// logs, or other contexts where a readable identifier is needed. /// - public virtual string DisplayName => this.Name ?? this.Id ?? this._id; // final fallback to _id in case Id override returns null + public virtual string DisplayName => this.Name ?? this.Id; /// /// Gets a description of the agent's purpose, capabilities, or behavior. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs index 353c82c996..4c0ff1a36d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs @@ -25,7 +25,7 @@ namespace Microsoft.Agents.AI; /// Derived classes can override specific methods to add custom behavior while maintaining compatibility with the agent interface. /// /// -public class DelegatingAIAgent : AIAgent +public abstract class DelegatingAIAgent : AIAgent { /// /// Initializes a new instance of the class with the specified inner agent. @@ -54,7 +54,7 @@ protected DelegatingAIAgent(AIAgent innerAgent) protected AIAgent InnerAgent { get; } /// - public override string Id => this.InnerAgent.Id; + protected override string? IdCore => this.InnerAgent.Id; /// public override string? Name => this.InnerAgent.Name; diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClient.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClient.cs index 8acafc8fc3..f31c570508 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClient.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClient.cs @@ -23,11 +23,6 @@ internal sealed class AzureAIProjectChatClient : DelegatingChatClient private readonly AgentRecord? _agentRecord; private readonly ChatOptions? _chatOptions; private readonly AgentReference _agentReference; - /// - /// The usage of a no-op model is a necessary change to avoid OpenAIClients to throw exceptions when - /// used with Azure AI Agents as the model used is now defined at the agent creation time. - /// - private const string NoOpModel = "no-op"; /// /// Initializes a new instance of the class. @@ -42,7 +37,7 @@ internal sealed class AzureAIProjectChatClient : DelegatingChatClient internal AzureAIProjectChatClient(AIProjectClient aiProjectClient, AgentReference agentReference, string? defaultModelId, ChatOptions? chatOptions) : base(Throw.IfNull(aiProjectClient) .GetProjectOpenAIClient() - .GetOpenAIResponseClient(defaultModelId ?? NoOpModel) + .GetProjectResponsesClientForAgent(agentReference) .AsIChatClient()) { this._agentClient = aiProjectClient; @@ -132,13 +127,15 @@ private ChatOptions GetAgentEnabledChatOptions(ChatOptions? options) agentEnabledChatOptions.RawRepresentationFactory = (client) => { - if (originalFactory?.Invoke(this) is not ResponseCreationOptions responseCreationOptions) + if (originalFactory?.Invoke(this) is not CreateResponseOptions responseCreationOptions) { - responseCreationOptions = new ResponseCreationOptions(); + responseCreationOptions = new CreateResponseOptions(); } - ResponseCreationOptionsExtensions.set_Agent(responseCreationOptions, this._agentReference); - ResponseCreationOptionsExtensions.set_Model(responseCreationOptions, null); + responseCreationOptions.Agent = this._agentReference; +#pragma warning disable SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + responseCreationOptions.Patch.Remove("$.model"u8); +#pragma warning restore SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. return responseCreationOptions; }; diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs index dfbdad8e98..7319bb13eb 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs @@ -400,7 +400,7 @@ public static ChatClientAgent CreateAIAgent( }; // Attempt to capture breaking glass options from the raw representation factory that match the agent definition. - if (options.ChatOptions?.RawRepresentationFactory?.Invoke(new NoOpChatClient()) is ResponseCreationOptions respCreationOptions) + if (options.ChatOptions?.RawRepresentationFactory?.Invoke(new NoOpChatClient()) is CreateResponseOptions respCreationOptions) { agentDefinition.ReasoningOptions = respCreationOptions.ReasoningOptions; } @@ -466,7 +466,7 @@ public static async Task CreateAIAgentAsync( }; // Attempt to capture breaking glass options from the raw representation factory that match the agent definition. - if (options.ChatOptions?.RawRepresentationFactory?.Invoke(new NoOpChatClient()) is ResponseCreationOptions respCreationOptions) + if (options.ChatOptions?.RawRepresentationFactory?.Invoke(new NoOpChatClient()) is CreateResponseOptions respCreationOptions) { agentDefinition.ReasoningOptions = respCreationOptions.ReasoningOptions; } diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosCheckpointStore.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosCheckpointStore.cs index 62987b1dfc..e0073feaf9 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosCheckpointStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosCheckpointStore.cs @@ -217,9 +217,7 @@ protected virtual void Dispose(bool disposing) } } - /// - /// Represents a checkpoint document stored in Cosmos DB. - /// + /// Represents a checkpoint document stored in Cosmos DB. internal sealed class CosmosCheckpointDocument { [JsonProperty("id")] diff --git a/dotnet/src/Microsoft.Agents.AI.DevUI/DevUIMiddleware.cs b/dotnet/src/Microsoft.Agents.AI.DevUI/DevUIMiddleware.cs index 2d6fcbd7e4..ac585ad39a 100644 --- a/dotnet/src/Microsoft.Agents.AI.DevUI/DevUIMiddleware.cs +++ b/dotnet/src/Microsoft.Agents.AI.DevUI/DevUIMiddleware.cs @@ -81,7 +81,7 @@ public async Task HandleRequestAsync(HttpContext context) } context.Response.StatusCode = StatusCodes.Status301MovedPermanently; - context.Response.Headers.Location = redirectUrl; + context.Response.Headers.Location = redirectUrl; // CodeQL [SM04598] justification: The redirect URL is constructed from a server-configured base path (_basePath), not user input. The query string is only appended as parameters and cannot change the redirect destination since this is a relative URL. if (this._logger.IsEnabled(LogLevel.Debug)) { diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs index 166799a124..ec4ba3acf6 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs @@ -16,29 +16,34 @@ internal class AgentEntity(IServiceProvider services, CancellationToken cancella private readonly DurableTaskClient _client = services.GetRequiredService(); private readonly ILoggerFactory _loggerFactory = services.GetRequiredService(); private readonly IAgentResponseHandler? _messageHandler = services.GetService(); + private readonly DurableAgentsOptions _options = services.GetRequiredService(); private readonly CancellationToken _cancellationToken = cancellationToken != default ? cancellationToken : services.GetService()?.ApplicationStopping ?? CancellationToken.None; - public async Task RunAgentAsync(RunRequest request) + public Task RunAgentAsync(RunRequest request) { - AgentSessionId sessionId = this.Context.Id; - IReadOnlyDictionary> agents = - this._services.GetRequiredService>>(); - if (!agents.TryGetValue(sessionId.Name, out Func? agentFactory)) - { - throw new InvalidOperationException($"Agent '{sessionId.Name}' not found"); - } + return this.Run(request); + } - AIAgent agent = agentFactory(this._services); + // IDE1006 and VSTHRD200 disabled to allow method name to match the common cross-platform entity operation name. +#pragma warning disable IDE1006 +#pragma warning disable VSTHRD200 + public async Task Run(RunRequest request) +#pragma warning restore VSTHRD200 +#pragma warning restore IDE1006 + { + AgentSessionId sessionId = this.Context.Id; + AIAgent agent = this.GetAgent(sessionId); EntityAgentWrapper agentWrapper = new(agent, this.Context, request, this._services); // Logger category is Microsoft.DurableTask.Agents.{agentName}.{sessionId} - ILogger logger = this._loggerFactory.CreateLogger($"Microsoft.DurableTask.Agents.{agent.Name}.{sessionId.Key}"); + ILogger logger = this.GetLogger(agent.Name!, sessionId.Key); if (request.Messages.Count == 0) { logger.LogInformation("Ignoring empty request"); + return new AgentRunResponse(); } this.State.Data.ConversationHistory.Add(DurableAgentStateRequest.FromRunRequest(request)); @@ -113,6 +118,36 @@ async IAsyncEnumerable StreamResultsAsync() response.Usage?.TotalTokenCount); } + // Update TTL expiration time. Only schedule deletion check on first interaction. + // Subsequent interactions just update the expiration time; CheckAndDeleteIfExpiredAsync + // will reschedule the deletion check when it runs. + TimeSpan? timeToLive = this._options.GetTimeToLive(sessionId.Name); + if (timeToLive.HasValue) + { + DateTime newExpirationTime = DateTime.UtcNow.Add(timeToLive.Value); + bool isFirstInteraction = this.State.Data.ExpirationTimeUtc is null; + + this.State.Data.ExpirationTimeUtc = newExpirationTime; + logger.LogTTLExpirationTimeUpdated(sessionId, newExpirationTime); + + // Only schedule deletion check on the first interaction when entity is created. + // On subsequent interactions, we just update the expiration time. The scheduled + // CheckAndDeleteIfExpiredAsync will reschedule itself if the entity hasn't expired. + if (isFirstInteraction) + { + this.ScheduleDeletionCheck(sessionId, logger, timeToLive.Value); + } + } + else + { + // TTL is disabled. Clear the expiration time if it was previously set. + if (this.State.Data.ExpirationTimeUtc.HasValue) + { + logger.LogTTLExpirationTimeCleared(sessionId); + this.State.Data.ExpirationTimeUtc = null; + } + } + return response; } finally @@ -121,4 +156,78 @@ async IAsyncEnumerable StreamResultsAsync() DurableAgentContext.ClearCurrent(); } } + + /// + /// Checks if the entity has expired and deletes it if so, otherwise reschedules the deletion check. + /// + /// + /// This method is called by the durable task runtime when a CheckAndDeleteIfExpired signal is received. + /// + public void CheckAndDeleteIfExpired() + { + AgentSessionId sessionId = this.Context.Id; + AIAgent agent = this.GetAgent(sessionId); + ILogger logger = this.GetLogger(agent.Name!, sessionId.Key); + + DateTime currentTime = DateTime.UtcNow; + DateTime? expirationTime = this.State.Data.ExpirationTimeUtc; + + logger.LogTTLDeletionCheck(sessionId, expirationTime, currentTime); + + if (expirationTime.HasValue) + { + if (currentTime >= expirationTime.Value) + { + // Entity has expired, delete it + logger.LogTTLEntityExpired(sessionId, expirationTime.Value); + this.State = null!; + } + else + { + // Entity hasn't expired yet, reschedule the deletion check + TimeSpan? timeToLive = this._options.GetTimeToLive(sessionId.Name); + if (timeToLive.HasValue) + { + this.ScheduleDeletionCheck(sessionId, logger, timeToLive.Value); + } + } + } + } + + private void ScheduleDeletionCheck(AgentSessionId sessionId, ILogger logger, TimeSpan timeToLive) + { + DateTime currentTime = DateTime.UtcNow; + DateTime expirationTime = this.State.Data.ExpirationTimeUtc ?? currentTime.Add(timeToLive); + TimeSpan minimumDelay = this._options.MinimumTimeToLiveSignalDelay; + + // To avoid excessive scheduling, we schedule the deletion check for no less than the minimum delay. + DateTime scheduledTime = expirationTime > currentTime.Add(minimumDelay) + ? expirationTime + : currentTime.Add(minimumDelay); + + logger.LogTTLDeletionScheduled(sessionId, scheduledTime); + + // Schedule a signal to self to check for expiration + this.Context.SignalEntity( + this.Context.Id, + nameof(CheckAndDeleteIfExpired), // self-signal + options: new SignalEntityOptions { SignalTime = scheduledTime }); + } + + private AIAgent GetAgent(AgentSessionId sessionId) + { + IReadOnlyDictionary> agents = + this._services.GetRequiredService>>(); + if (!agents.TryGetValue(sessionId.Name, out Func? agentFactory)) + { + throw new InvalidOperationException($"Agent '{sessionId.Name}' not found"); + } + + return agentFactory(this._services); + } + + private ILogger GetLogger(string agentName, string sessionKey) + { + return this._loggerFactory.CreateLogger($"Microsoft.DurableTask.Agents.{agentName}.{sessionKey}"); + } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md b/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md index d2cdc7cd41..ccc6aa7181 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md @@ -1,5 +1,12 @@ # Release History +## [Unreleased] + +### Changed + +- Added TTL configuration for durable agent entities ([#2679](https://github.com/microsoft/agent-framework/pull/2679)) +- Switch to new "Run" method name ([#2843](https://github.com/microsoft/agent-framework/pull/2843)) + ## v1.0.0-preview.251204.1 - Added orchestration ID to durable agent entity state ([#2137](https://github.com/microsoft/agent-framework/pull/2137)) diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DefaultDurableAgentClient.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DefaultDurableAgentClient.cs index 2086a00ecb..9005641860 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DefaultDurableAgentClient.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DefaultDurableAgentClient.cs @@ -22,7 +22,7 @@ public async Task RunAgentAsync( await this._client.Entities.SignalEntityAsync( sessionId, - nameof(AgentEntity.RunAgentAsync), + nameof(AgentEntity.Run), request, cancellation: cancellationToken); diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs index 021c8f22c7..2035b792fd 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs @@ -107,7 +107,7 @@ public override async Task RunAsync( { return await this._context.Entities.CallEntityAsync( durableThread.SessionId, - nameof(AgentEntity.RunAgentAsync), + nameof(AgentEntity.Run), request); } catch (EntityOperationFailedException e) when (e.FailureDetails.ErrorType == "EntityTaskNotFound") diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentsOptions.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentsOptions.cs index f2ac3f4c9a..cefcad323a 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentsOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentsOptions.cs @@ -9,23 +9,67 @@ public sealed class DurableAgentsOptions { // Agent names are case-insensitive private readonly Dictionary> _agentFactories = new(StringComparer.OrdinalIgnoreCase); + private readonly Dictionary _agentTimeToLive = new(StringComparer.OrdinalIgnoreCase); internal DurableAgentsOptions() { } + /// + /// Gets or sets the default time-to-live (TTL) for agent entities. + /// + /// + /// If an agent entity is idle for this duration, it will be automatically deleted. + /// Defaults to 14 days. Set to to disable TTL for agents without explicit TTL configuration. + /// + public TimeSpan? DefaultTimeToLive { get; set; } = TimeSpan.FromDays(14); + + /// + /// Gets or sets the minimum delay for scheduling TTL deletion signals. Defaults to 5 minutes. + /// + /// + /// This property is primarily useful for testing (where shorter delays are needed) or for + /// shorter-lived agents in workflows that need more rapid cleanup. The maximum allowed value is 5 minutes. + /// Reducing the minimum deletion delay below 5 minutes can be useful for testing or for ensuring rapid cleanup of short-lived agent sessions. + /// However, this can also increase the load on the system and should be used with caution. + /// + /// Thrown when the value exceeds 5 minutes. + public TimeSpan MinimumTimeToLiveSignalDelay + { + get; + set + { + const int MaximumDelayMinutes = 5; + if (value > TimeSpan.FromMinutes(MaximumDelayMinutes)) + { + throw new ArgumentOutOfRangeException( + nameof(value), + value, + $"The minimum time-to-live signal delay cannot exceed {MaximumDelayMinutes} minutes."); + } + + field = value; + } + } = TimeSpan.FromMinutes(5); + /// /// Adds an AI agent factory to the options. /// /// The name of the agent. /// The factory function to create the agent. + /// Optional time-to-live for this agent's entities. If not specified, uses . /// The options instance. /// Thrown when or is null. - public DurableAgentsOptions AddAIAgentFactory(string name, Func factory) + public DurableAgentsOptions AddAIAgentFactory(string name, Func factory, TimeSpan? timeToLive = null) { ArgumentNullException.ThrowIfNull(name); ArgumentNullException.ThrowIfNull(factory); this._agentFactories.Add(name, factory); + if (timeToLive.HasValue) + { + this._agentTimeToLive[name] = timeToLive; + } + return this; } @@ -50,12 +94,13 @@ public DurableAgentsOptions AddAIAgents(params IEnumerable agents) /// Adds an AI agent to the options. /// /// The agent to add. + /// Optional time-to-live for this agent's entities. If not specified, uses . /// The options instance. /// Thrown when is null. /// /// Thrown when is null or whitespace or when an agent with the same name has already been registered. /// - public DurableAgentsOptions AddAIAgent(AIAgent agent) + public DurableAgentsOptions AddAIAgent(AIAgent agent, TimeSpan? timeToLive = null) { ArgumentNullException.ThrowIfNull(agent); @@ -70,6 +115,11 @@ public DurableAgentsOptions AddAIAgent(AIAgent agent) } this._agentFactories.Add(agent.Name, sp => agent); + if (timeToLive.HasValue) + { + this._agentTimeToLive[agent.Name] = timeToLive; + } + return this; } @@ -81,4 +131,14 @@ internal IReadOnlyDictionary> GetAgentFa { return this._agentFactories.AsReadOnly(); } + + /// + /// Gets the time-to-live for a specific agent, or the default TTL if not specified. + /// + /// The name of the agent. + /// The time-to-live for the agent, or the default TTL if not specified. + internal TimeSpan? GetTimeToLive(string agentName) + { + return this._agentTimeToLive.TryGetValue(agentName, out TimeSpan? ttl) ? ttl : this.DefaultTimeToLive; + } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/EntityAgentWrapper.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/EntityAgentWrapper.cs index 34c9208967..8822ebcc39 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/EntityAgentWrapper.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/EntityAgentWrapper.cs @@ -19,7 +19,7 @@ internal sealed class EntityAgentWrapper( private readonly IServiceProvider? _entityScopedServices = entityScopedServices; // The ID of the agent is always the entity ID. - public override string Id => this._entityContext.Id.ToString(); + protected override string? IdCore => this._entityContext.Id.ToString(); public override async Task RunAsync( IEnumerable messages, diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/Logs.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/Logs.cs index 0bec1e149c..ba310441df 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/Logs.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/Logs.cs @@ -46,4 +46,58 @@ public static partial void LogAgentResponse( Level = LogLevel.Information, Message = "Found response for agent with session ID '{SessionId}' with correlation ID '{CorrelationId}'")] public static partial void LogDonePollingForResponse(this ILogger logger, AgentSessionId sessionId, string correlationId); + + [LoggerMessage( + EventId = 6, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL expiration time updated to {ExpirationTime:O}")] + public static partial void LogTTLExpirationTimeUpdated( + this ILogger logger, + AgentSessionId sessionId, + DateTime expirationTime); + + [LoggerMessage( + EventId = 7, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL deletion signal scheduled for {ScheduledTime:O}")] + public static partial void LogTTLDeletionScheduled( + this ILogger logger, + AgentSessionId sessionId, + DateTime scheduledTime); + + [LoggerMessage( + EventId = 8, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL deletion check running. Expiration time: {ExpirationTime:O}, Current time: {CurrentTime:O}")] + public static partial void LogTTLDeletionCheck( + this ILogger logger, + AgentSessionId sessionId, + DateTime? expirationTime, + DateTime currentTime); + + [LoggerMessage( + EventId = 9, + Level = LogLevel.Information, + Message = "[{SessionId}] Entity expired and deleted due to TTL. Expiration time: {ExpirationTime:O}")] + public static partial void LogTTLEntityExpired( + this ILogger logger, + AgentSessionId sessionId, + DateTime expirationTime); + + [LoggerMessage( + EventId = 10, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL deletion signal rescheduled for {ScheduledTime:O}")] + public static partial void LogTTLRescheduled( + this ILogger logger, + AgentSessionId sessionId, + DateTime scheduledTime); + + [LoggerMessage( + EventId = 11, + Level = LogLevel.Information, + Message = "[{SessionId}] TTL expiration time cleared (TTL disabled)")] + public static partial void LogTTLExpirationTimeCleared( + this ILogger logger, + AgentSessionId sessionId); } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/ServiceCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/ServiceCollectionExtensions.cs index 2f435e0541..79d44924ca 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/ServiceCollectionExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/ServiceCollectionExtensions.cs @@ -85,6 +85,9 @@ internal static DurableAgentsOptions ConfigureDurableAgents( // The agent dictionary contains the real agent factories, which is used by the agent entities. services.AddSingleton(agents); + // Register the options so AgentEntity can access TTL configuration + services.AddSingleton(options); + // The keyed services are used to resolve durable agent *proxy* instances for external clients. foreach (var factory in agents) { diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateData.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateData.cs index f51820dcf5..745f619f48 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateData.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/State/DurableAgentStateData.cs @@ -17,6 +17,13 @@ internal sealed class DurableAgentStateData [JsonPropertyName("conversationHistory")] public IList ConversationHistory { get; init; } = []; + /// + /// Gets or sets the expiration time (UTC) for this agent entity. + /// If the entity is idle beyond this time, it will be automatically deleted. + /// + [JsonPropertyName("expirationTimeUtc")] + public DateTime? ExpirationTimeUtc { get; set; } + /// /// Gets any additional data found during deserialization that does not map to known properties. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/Models/ConversationReference.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/Models/ConversationReference.cs index dc38375331..d5a1d96240 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/Models/ConversationReference.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.OpenAI/Responses/Models/ConversationReference.cs @@ -84,22 +84,18 @@ public override void Write(Utf8JsonWriter writer, ConversationReference value, J return; } - // If only ID is present and no metadata, serialize as a simple string - if (value.Metadata is null || value.Metadata.Count == 0) + // Ideally if only ID is present and no metadata, we would serialize as a simple string. + // However, while a request's "conversation" property can be either a string or an object + // containing a string, a response's "conversation" property is always an object. Since + // here we don't know which scenario we're in, we always serialize as an object, which works + // in any scenario. + writer.WriteStartObject(); + writer.WriteString("id", value.Id); + if (value.Metadata is not null) { - writer.WriteStringValue(value.Id); - } - else - { - // Otherwise, serialize as an object - writer.WriteStartObject(); - writer.WriteString("id", value.Id); - if (value.Metadata is not null) - { - writer.WritePropertyName("metadata"); - JsonSerializer.Serialize(writer, value.Metadata, OpenAIHostingJsonContext.Default.DictionaryStringString); - } - writer.WriteEndObject(); + writer.WritePropertyName("metadata"); + JsonSerializer.Serialize(writer, value.Metadata, OpenAIHostingJsonContext.Default.DictionaryStringString); } + writer.WriteEndObject(); } } diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs index 4abc6915a6..d487ba00e1 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AIAgentWithOpenAIExtensions.cs @@ -73,22 +73,22 @@ public static AsyncCollectionResult RunStreamingA } /// - /// Runs the AI agent with a collection of OpenAI response items and returns the response as a native OpenAI . + /// Runs the AI agent with a collection of OpenAI response items and returns the response as a native OpenAI . /// /// The AI agent to run. /// The collection of OpenAI response items to send to the agent. /// The conversation thread to continue with this invocation. If not provided, creates a new thread. The thread will be mutated with the provided messages and agent response. /// Optional parameters for agent invocation. /// The to monitor for cancellation requests. The default is . - /// A representing the asynchronous operation that returns a native OpenAI response. + /// A representing the asynchronous operation that returns a native OpenAI response. /// Thrown when or is . - /// Thrown when the agent's response cannot be converted to an , typically when the underlying representation is not an OpenAI response. + /// Thrown when the agent's response cannot be converted to an , typically when the underlying representation is not an OpenAI response. /// Thrown when any message in has a type that is not supported by the message conversion method. /// /// This method converts the OpenAI response items to the Microsoft Extensions AI format using the appropriate conversion method, - /// runs the agent with the converted message collection, and then extracts the native OpenAI from the response using . + /// runs the agent with the converted message collection, and then extracts the native OpenAI from the response using . /// - public static async Task RunAsync(this AIAgent agent, IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + public static async Task RunAsync(this AIAgent agent, IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { Throw.IfNull(agent); Throw.IfNull(messages); diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs index 9a164d862b..44844e64f5 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/AgentRunResponseExtensions.cs @@ -29,17 +29,17 @@ response.RawRepresentation as ChatCompletion ?? } /// - /// Creates or extracts a native OpenAI object from an . + /// Creates or extracts a native OpenAI object from an . /// /// The agent response. - /// The OpenAI object. + /// The OpenAI object. /// is . - public static OpenAIResponse AsOpenAIResponse(this AgentRunResponse response) + public static ResponseResult AsOpenAIResponse(this AgentRunResponse response) { Throw.IfNull(response); return - response.RawRepresentation as OpenAIResponse ?? - response.AsChatResponse().AsOpenAIResponse(); + response.RawRepresentation as ResponseResult ?? + response.AsChatResponse().AsOpenAIResponseResult(); } } diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs index 0d48147c77..224bf5db95 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIResponseClientExtensions.cs @@ -8,7 +8,7 @@ namespace OpenAI.Responses; /// -/// Provides extension methods for +/// Provides extension methods for /// to simplify the creation of AI agents that work with OpenAI services. /// /// @@ -20,9 +20,9 @@ namespace OpenAI.Responses; public static class OpenAIResponseClientExtensions { /// - /// Creates an AI agent from an using the OpenAI Response API. + /// Creates an AI agent from an using the OpenAI Response API. /// - /// The to use for the agent. + /// The to use for the agent. /// Optional system instructions that define the agent's behavior and personality. /// Optional name for the agent for identification purposes. /// Optional description of the agent's capabilities and purpose. @@ -33,7 +33,7 @@ public static class OpenAIResponseClientExtensions /// An instance backed by the OpenAI Response service. /// Thrown when is . public static ChatClientAgent CreateAIAgent( - this OpenAIResponseClient client, + this ResponsesClient client, string? instructions = null, string? name = null, string? description = null, @@ -61,9 +61,9 @@ public static ChatClientAgent CreateAIAgent( } /// - /// Creates an AI agent from an using the OpenAI Response API. + /// Creates an AI agent from an using the OpenAI Response API. /// - /// The to use for the agent. + /// The to use for the agent. /// Full set of options to configure the agent. /// Provides a way to customize the creation of the underlying used by the agent. /// Optional logger factory for enabling logging within the agent. @@ -71,7 +71,7 @@ public static ChatClientAgent CreateAIAgent( /// An instance backed by the OpenAI Response service. /// Thrown when or is . public static ChatClientAgent CreateAIAgent( - this OpenAIResponseClient client, + this ResponsesClient client, ChatClientAgentOptions options, Func? clientFactory = null, ILoggerFactory? loggerFactory = null, diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.AzureAI/AzureAgentProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.AzureAI/AzureAgentProvider.cs index c4a613901c..d4010a43c2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.AzureAI/AzureAgentProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.AzureAI/AzureAgentProvider.cs @@ -111,7 +111,7 @@ public override async IAsyncEnumerable InvokeAgentAsync( if (inputArguments is not null) { JsonNode jsonNode = ConvertDictionaryToJson(inputArguments); - ResponseCreationOptions responseCreationOptions = new(); + CreateResponseOptions responseCreationOptions = new(); #pragma warning disable SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. responseCreationOptions.Patch.Set("$.structured_inputs"u8, BinaryData.FromString(jsonNode.ToJsonString())); #pragma warning restore SCME0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. @@ -206,7 +206,7 @@ private async Task GetAgentAsync(AgentVersion agentVersion, Cancellatio public override async Task GetMessageAsync(string conversationId, string messageId, CancellationToken cancellationToken = default) { AgentResponseItem responseItem = await this.GetConversationClient().GetProjectConversationItemAsync(conversationId, messageId, include: null, cancellationToken).ConfigureAwait(false); - ResponseItem[] items = [responseItem.AsOpenAIResponseItem()]; + ResponseItem[] items = [responseItem.AsResponseResultItem()]; return items.AsChatMessages().Single(); } @@ -223,7 +223,7 @@ public override async IAsyncEnumerable GetMessagesAsync( await foreach (AgentResponseItem responseItem in this.GetConversationClient().GetProjectConversationItemsAsync(conversationId, null, limit, order.ToString(), after, before, include: null, cancellationToken).ConfigureAwait(false)) { - ResponseItem[] items = [responseItem.AsOpenAIResponseItem()]; + ResponseItem[] items = [responseItem.AsResponseResultItem()]; foreach (ChatMessage message in items.AsChatMessages()) { yield return message; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/DirectEdgeData.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/DirectEdgeData.cs index 2119bd775b..7d61c939cd 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/DirectEdgeData.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/DirectEdgeData.cs @@ -11,7 +11,7 @@ namespace Microsoft.Agents.AI.Workflows; /// public sealed class DirectEdgeData : EdgeData { - internal DirectEdgeData(string sourceId, string sinkId, EdgeId id, PredicateT? condition = null) : base(id) + internal DirectEdgeData(string sourceId, string sinkId, EdgeId id, PredicateT? condition = null, string? label = null) : base(id, label) { this.SourceId = sourceId; this.SinkId = sinkId; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/EdgeData.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/EdgeData.cs index 7771b3966e..570bc79bc0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/EdgeData.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/EdgeData.cs @@ -14,10 +14,16 @@ public abstract class EdgeData /// internal abstract EdgeConnection Connection { get; } - internal EdgeData(EdgeId id) + internal EdgeData(EdgeId id, string? label = null) { this.Id = id; + this.Label = label; } internal EdgeId Id { get; } + + /// + /// An optional label for the edge, allowing for arbitrary metadata to be associated with it. + /// + public string? Label { get; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/FanInEdgeData.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/FanInEdgeData.cs index 0cb2b38378..1132fca334 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/FanInEdgeData.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/FanInEdgeData.cs @@ -10,7 +10,7 @@ namespace Microsoft.Agents.AI.Workflows; /// internal sealed class FanInEdgeData : EdgeData { - internal FanInEdgeData(List sourceIds, string sinkId, EdgeId id) : base(id) + internal FanInEdgeData(List sourceIds, string sinkId, EdgeId id, string? label) : base(id, label) { this.SourceIds = sourceIds; this.SinkId = sinkId; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/FanOutEdgeData.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/FanOutEdgeData.cs index 9d9ddf4cea..86a940c1b6 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/FanOutEdgeData.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/FanOutEdgeData.cs @@ -13,7 +13,7 @@ namespace Microsoft.Agents.AI.Workflows; /// internal sealed class FanOutEdgeData : EdgeData { - internal FanOutEdgeData(string sourceId, List sinkIds, EdgeId edgeId, AssignerF? assigner = null) : base(edgeId) + internal FanOutEdgeData(string sourceId, List sinkIds, EdgeId edgeId, AssignerF? assigner = null, string? label = null) : base(edgeId, label) { this.SourceId = sourceId; this.SinkIds = sinkIds; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Visualization/WorkflowVisualizer.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Visualization/WorkflowVisualizer.cs index ebf6f08ffb..e1b69e9f9e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Visualization/WorkflowVisualizer.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Visualization/WorkflowVisualizer.cs @@ -99,10 +99,30 @@ private static void EmitWorkflowDigraph(Workflow workflow, List lines, s } // Emit normal edges - foreach (var (src, target, isConditional) in ComputeNormalEdges(workflow)) + foreach (var (src, target, isConditional, label) in ComputeNormalEdges(workflow)) { - var edgeAttr = isConditional ? " [style=dashed, label=\"conditional\"]" : ""; - lines.Add($"{indent}\"{MapId(src)}\" -> \"{MapId(target)}\"{edgeAttr};"); + // Build edge attributes + var attributes = new List(); + + // Add style for conditional edges + if (isConditional) + { + attributes.Add("style=dashed"); + } + + // Add label (custom label or default "conditional" for conditional edges) + if (label != null) + { + attributes.Add($"label=\"{EscapeDotLabel(label)}\""); + } + else if (isConditional) + { + attributes.Add("label=\"conditional\""); + } + + // Combine attributes + var attrString = attributes.Count > 0 ? $" [{string.Join(", ", attributes)}]" : ""; + lines.Add($"{indent}\"{MapId(src)}\" -> \"{MapId(target)}\"{attrString};"); } } @@ -133,12 +153,7 @@ private static void EmitSubWorkflowsDigraph(Workflow workflow, List line private static void EmitWorkflowMermaid(Workflow workflow, List lines, string indent, string? ns = null) { - string sanitize(string input) - { - return input; - } - - string MapId(string id) => ns != null ? $"{sanitize(ns)}/{sanitize(id)}" : id; + string MapId(string id) => ns != null ? $"{ns}/{id}" : id; // Add start node var startExecutorId = workflow.StartExecutorId; @@ -175,14 +190,23 @@ string sanitize(string input) } // Emit normal edges - foreach (var (src, target, isConditional) in ComputeNormalEdges(workflow)) + foreach (var (src, target, isConditional, label) in ComputeNormalEdges(workflow)) { if (isConditional) { - lines.Add($"{indent}{MapId(src)} -. conditional .--> {MapId(target)};"); + string effectiveLabel = label != null ? EscapeMermaidLabel(label) : "conditional"; + + // Conditional edge, with user label or default + lines.Add($"{indent}{MapId(src)} -. {effectiveLabel} .--> {MapId(target)};"); + } + else if (label != null) + { + // Regular edge with label + lines.Add($"{indent}{MapId(src)} -->|{EscapeMermaidLabel(label)}| {MapId(target)};"); } else { + // Regular edge without label lines.Add($"{indent}{MapId(src)} --> {MapId(target)};"); } } @@ -214,9 +238,9 @@ string sanitize(string input) return result; } - private static List<(string Source, string Target, bool IsConditional)> ComputeNormalEdges(Workflow workflow) + private static List<(string Source, string Target, bool IsConditional, string? Label)> ComputeNormalEdges(Workflow workflow) { - var edges = new List<(string, string, bool)>(); + var edges = new List<(string, string, bool, string?)>(); foreach (var edgeGroup in workflow.Edges.Values.SelectMany(x => x)) { if (edgeGroup.Kind == EdgeKind.FanIn) @@ -229,14 +253,15 @@ string sanitize(string input) case EdgeKind.Direct when edgeGroup.DirectEdgeData != null: var directData = edgeGroup.DirectEdgeData; var isConditional = directData.Condition != null; - edges.Add((directData.SourceId, directData.SinkId, isConditional)); + var label = directData.Label; + edges.Add((directData.SourceId, directData.SinkId, isConditional, label)); break; case EdgeKind.FanOut when edgeGroup.FanOutEdgeData != null: var fanOutData = edgeGroup.FanOutEdgeData; foreach (var sinkId in fanOutData.SinkIds) { - edges.Add((fanOutData.SourceId, sinkId, false)); + edges.Add((fanOutData.SourceId, sinkId, false, fanOutData.Label)); } break; } @@ -276,5 +301,24 @@ private static bool TryGetNestedWorkflow(ExecutorBinding binding, [NotNullWhen(t return false; } + // Helper method to escape special characters in DOT labels + private static string EscapeDotLabel(string label) + { + return label.Replace("\"", "\\\"").Replace("\n", "\\n"); + } + + // Helper method to escape special characters in Mermaid labels + private static string EscapeMermaidLabel(string label) + { + return label + .Replace("&", "&") // Must be first to avoid double-escaping + .Replace("|", "|") // Pipe breaks Mermaid delimiter syntax + .Replace("\"", """) // Quote character + .Replace("<", "<") // Less than + .Replace(">", ">") // Greater than + .Replace("\n", "
") // Newline to HTML break + .Replace("\r", ""); // Remove carriage return + } + #endregion } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowBuilder.cs index 93f7850135..4b6980d433 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowBuilder.cs @@ -168,6 +168,18 @@ private HashSet EnsureEdgesFor(string sourceId) return edges; } + /// + /// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a + /// condition. + /// + /// The executor that acts as the source node of the edge. Cannot be null. + /// The executor that acts as the target node of the edge. Cannot be null. + /// The current instance of . + /// Thrown if an unconditional edge between the specified source and target + /// executors already exists. + public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target) + => this.AddEdge(source, target, null, false); + /// /// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a /// condition. @@ -182,6 +194,20 @@ private HashSet EnsureEdgesFor(string sourceId) public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target, bool idempotent = false) => this.AddEdge(source, target, null, idempotent); + /// + /// Adds a directed edge from the specified source executor to the target executor. + /// + /// The executor that acts as the source node of the edge. Cannot be null. + /// The executor that acts as the target node of the edge. Cannot be null. + /// An optional label for the edge. Will be used in visualizations. + /// If set to , adding the same edge multiple times will be a NoOp, + /// rather than an error. + /// The current instance of . + /// Thrown if an unconditional edge between the specified source and target + /// executors already exists. + public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target, string? label = null, bool idempotent = false) + => this.AddEdge(source, target, null, label, idempotent); + internal static Func? CreateConditionFunc(Func? condition) { if (condition is null) @@ -222,6 +248,20 @@ public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target, b private EdgeId TakeEdgeId() => new(Interlocked.Increment(ref this._edgeCount)); + /// + /// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a + /// condition. + /// + /// The executor that acts as the source node of the edge. Cannot be null. + /// The executor that acts as the target node of the edge. Cannot be null. + /// An optional predicate that determines whether the edge should be followed based on the input. + /// If null, the edge is always activated when the source sends a message. + /// The current instance of . + /// Thrown if an unconditional edge between the specified source and target + /// executors already exists. + public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target, Func? condition = null) + => this.AddEdge(source, target, condition, label: null, false); + /// /// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a /// condition. @@ -236,6 +276,23 @@ public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target, b /// Thrown if an unconditional edge between the specified source and target /// executors already exists. public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target, Func? condition = null, bool idempotent = false) + => this.AddEdge(source, target, condition, label: null, idempotent); + + /// + /// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a + /// condition. + /// + /// The executor that acts as the source node of the edge. Cannot be null. + /// The executor that acts as the target node of the edge. Cannot be null. + /// An optional predicate that determines whether the edge should be followed based on the input. + /// An optional label for the edge. Will be used in visualizations. + /// If set to , adding the same edge multiple times will be a NoOp, + /// rather than an error. + /// If null, the edge is always activated when the source sends a message. + /// The current instance of . + /// Thrown if an unconditional edge between the specified source and target + /// executors already exists. + public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target, Func? condition = null, string? label = null, bool idempotent = false) { // Add an edge from source to target with an optional condition. // This is a low-level builder method that does not enforce any specific executor type. @@ -256,7 +313,7 @@ public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target "You cannot add another edge without a condition for the same source and target."); } - DirectEdgeData directEdge = new(this.Track(source).Id, this.Track(target).Id, this.TakeEdgeId(), CreateConditionFunc(condition)); + DirectEdgeData directEdge = new(this.Track(source).Id, this.Track(target).Id, this.TakeEdgeId(), CreateConditionFunc(condition), label); this.EnsureEdgesFor(source.Id).Add(new(directEdge)); @@ -275,6 +332,19 @@ public WorkflowBuilder AddEdge(ExecutorBinding source, ExecutorBinding target public WorkflowBuilder AddFanOutEdge(ExecutorBinding source, IEnumerable targets) => this.AddFanOutEdge(source, targets, null); + /// + /// Adds a fan-out edge from the specified source executor to one or more target executors, optionally using a + /// custom partitioning function. + /// + /// If a partitioner function is provided, it will be used to distribute input across the target + /// executors. The order of targets determines their mapping in the partitioning process. + /// The source executor from which the fan-out edge originates. Cannot be null. + /// One or more target executors that will receive the fan-out edge. Cannot be null or empty. + /// A label for the edge. Will be used in visualization. + /// The current instance of . + public WorkflowBuilder AddFanOutEdge(ExecutorBinding source, IEnumerable targets, string label) + => this.AddFanOutEdge(source, targets, null, label); + internal static Func>? CreateTargetAssignerFunc(Func>? targetAssigner) { if (targetAssigner is null) @@ -305,6 +375,21 @@ public WorkflowBuilder AddFanOutEdge(ExecutorBinding source, IEnumerableAn optional function that determines how input is assigned among the target executors. /// If null, messages will route to all targets. public WorkflowBuilder AddFanOutEdge(ExecutorBinding source, IEnumerable targets, Func>? targetSelector = null) + => this.AddFanOutEdge(source, targets, targetSelector, label: null); + + /// + /// Adds a fan-out edge from the specified source executor to one or more target executors, optionally using a + /// custom partitioning function. + /// + /// If a partitioner function is provided, it will be used to distribute input across the target + /// executors. The order of targets determines their mapping in the partitioning process. + /// The source executor from which the fan-out edge originates. Cannot be null. + /// One or more target executors that will receive the fan-out edge. Cannot be null or empty. + /// The current instance of . + /// An optional function that determines how input is assigned among the target executors. + /// If null, messages will route to all targets. + /// An optional label for the edge. Will be used in visualizations. + public WorkflowBuilder AddFanOutEdge(ExecutorBinding source, IEnumerable targets, Func>? targetSelector = null, string? label = null) { Throw.IfNull(source); Throw.IfNull(targets); @@ -321,7 +406,8 @@ public WorkflowBuilder AddFanOutEdge(ExecutorBinding source, IEnumerable(ExecutorBinding source, IEnumerableThe target executor that receives input from the specified source executors. Cannot be null. /// The current instance of . public WorkflowBuilder AddFanInEdge(IEnumerable sources, ExecutorBinding target) + => this.AddFanInEdge(sources, target, label: null); + + /// + /// Adds a fan-in edge to the workflow, connecting multiple source executors to a single target executor with an + /// optional trigger condition. + /// + /// This method establishes a fan-in relationship, allowing the target executor to be activated + /// based on the completion or state of multiple sources. The trigger parameter can be used to customize activation + /// behavior. + /// One or more source executors that provide input to the target. Cannot be null or empty. + /// The target executor that receives input from the specified source executors. Cannot be null. + /// An optional label for the edge. Will be used in visualizations. + /// The current instance of . + public WorkflowBuilder AddFanInEdge(IEnumerable sources, ExecutorBinding target, string? label = null) { Throw.IfNull(target); Throw.IfNull(sources); @@ -354,7 +454,8 @@ public WorkflowBuilder AddFanInEdge(IEnumerable sources, Execut FanInEdgeData edgeData = new( sourceIds, this.Track(target).Id, - this.TakeEdgeId()); + this.TakeEdgeId(), + label); foreach (string sourceId in edgeData.SourceIds) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs index 98dc5903bf..70fcee15df 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs @@ -39,7 +39,7 @@ public WorkflowHostAgent(Workflow workflow, string? id = null, string? name = nu this._describeTask = this._workflow.DescribeProtocolAsync().AsTask(); } - public override string Id => this._id ?? base.Id; + protected override string? IdCore => this._id; public override string? Name { get; } public override string? Description { get; } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index df7477241c..a5a34d24a9 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -121,7 +121,7 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, public IChatClient ChatClient { get; } /// - public override string Id => this._agentOptions?.Id ?? base.Id; + protected override string? IdCore => this._agentOptions?.Id; /// public override string? Name => this._agentOptions?.Name; diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index e982c8081f..883b317f5e 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -89,7 +89,7 @@ private async Task> GetChatHistoryFromConversationAsync(string List messages = []; await foreach (AgentResponseItem item in this._client.GetProjectOpenAIClient().GetProjectConversationsClient().GetProjectConversationItemsAsync(conversationId, order: "asc")) { - var openAIItem = item.AsOpenAIResponseItem(); + var openAIItem = item.AsResponseResultItem(); if (openAIItem is MessageResponseItem messageItem) { messages.Add(new ChatMessage diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs index 5111a97ad1..e3bda2081a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs @@ -214,13 +214,31 @@ public async Task InvokeStreamingWithSingleMessageCallsMockedInvokeWithMessageIn [Fact] public void ValidateAgentIDIsIdempotent() { + // Arrange var agent = new MockAgent(); + // Act string id = agent.Id; + + // Assert Assert.NotNull(id); Assert.Equal(id, agent.Id); } + [Fact] + public void ValidateAgentIDCanBeProvidedByDerivedAgentClass() + { + // Arrange + var agent = new MockAgent(id: "test-agent-id"); + + // Act + string id = agent.Id; + + // Assert + Assert.NotNull(id); + Assert.Equal("test-agent-id", id); + } + #region GetService Method Tests /// @@ -344,6 +362,13 @@ public abstract class TestAgentThread : AgentThread; private sealed class MockAgent : AIAgent { + public MockAgent(string? id = null) + { + this.IdCore = id; + } + + protected override string? IdCore { get; } + public override AgentThread GetNewThread() => throw new NotImplementedException(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs index 4dca99a77c..50271b7eee 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs @@ -6,6 +6,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; +using Moq.Protected; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -31,7 +32,7 @@ public DelegatingAIAgentTests() this._testThread = new TestAgentThread(); // Setup inner agent mock - this._innerAgentMock.Setup(x => x.Id).Returns("test-agent-id"); + this._innerAgentMock.Protected().SetupGet("IdCore").Returns("test-agent-id"); this._innerAgentMock.Setup(x => x.Name).Returns("Test Agent"); this._innerAgentMock.Setup(x => x.Description).Returns("Test Description"); this._innerAgentMock.Setup(x => x.GetNewThread()).Returns(this._testThread); @@ -93,7 +94,7 @@ public void Id_DelegatesToInnerAgent() // Assert Assert.Equal("test-agent-id", id); - this._innerAgentMock.Verify(x => x.Id, Times.Once); + this._innerAgentMock.Protected().VerifyGet("IdCore", Times.Once()); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs index 7ce0611c99..3dbd3ec367 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs @@ -10,7 +10,6 @@ using Azure.Identity; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.AI; -using Xunit; namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; @@ -59,6 +58,9 @@ public sealed class CosmosChatMessageStoreTests : IAsyncLifetime, IDisposable public async Task InitializeAsync() { + // Fail fast if emulator is not available + this.SkipIfEmulatorNotAvailable(); + // Check environment variable to determine if we should preserve containers // Set COSMOS_PRESERVE_CONTAINERS=true to keep containers and data for inspection this._preserveContainer = string.Equals(Environment.GetEnvironmentVariable("COSMOS_PRESERVE_CONTAINERS"), "true", StringComparison.OrdinalIgnoreCase); diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosCheckpointStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosCheckpointStoreTests.cs index 8f5749b187..dc75b34758 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosCheckpointStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosCheckpointStoreTests.cs @@ -7,7 +7,6 @@ using Microsoft.Agents.AI.Workflows; using Microsoft.Agents.AI.Workflows.Checkpointing; using Microsoft.Azure.Cosmos; -using Xunit; namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; @@ -58,6 +57,9 @@ private static JsonSerializerOptions CreateJsonOptions() public async Task InitializeAsync() { + // Fail fast if emulator is not available + this.SkipIfEmulatorNotAvailable(); + // Check environment variable to determine if we should preserve containers // Set COSMOS_PRESERVE_CONTAINERS=true to keep containers and data for inspection this._preserveContainer = string.Equals(Environment.GetEnvironmentVariable("COSMOS_PRESERVE_CONTAINERS"), "true", StringComparison.OrdinalIgnoreCase); diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs index 98e40ad4fb..b615bf1cd6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/AgentEntityTests.cs @@ -81,6 +81,64 @@ await simpleAgentProxy.RunAsync( Assert.Null(request.OrchestrationId); } + [Theory] + [InlineData("run")] + [InlineData("Run")] + [InlineData("RunAgentAsync")] + public async Task RunAgentMethodNamesAllWorkAsync(string runAgentMethodName) + { + // Setup + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + name: "TestAgent", + instructions: "You are a helpful assistant that always responds with a friendly greeting." + ); + + using TestHelper testHelper = TestHelper.Start([simpleAgent], this._outputHelper); + + // A proxy agent is needed to call the hosted test agent + AIAgent simpleAgentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); + + AgentThread thread = simpleAgentProxy.GetNewThread(); + + DurableTaskClient client = testHelper.GetClient(); + + AgentSessionId sessionId = thread.GetService(); + EntityInstanceId expectedEntityId = new($"dafx-{simpleAgent.Name}", sessionId.Key); + + EntityMetadata? entity = await client.Entities.GetEntityAsync(expectedEntityId, false, this.TestTimeoutToken); + + Assert.Null(entity); + + // Act: send a prompt to the agent + await client.Entities.SignalEntityAsync( + expectedEntityId, + runAgentMethodName, + new RunRequest("Hello!"), + cancellation: this.TestTimeoutToken); + + while (!this.TestTimeoutToken.IsCancellationRequested) + { + await Task.Delay(500, this.TestTimeoutToken); + + // Assert: verify the agent state was stored with the correct entity name prefix + entity = await client.Entities.GetEntityAsync(expectedEntityId, true, this.TestTimeoutToken); + + if (entity is not null) + { + break; + } + } + + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + DurableAgentState state = entity.State.ReadAs(); + + DurableAgentStateRequest request = Assert.Single(state.Data.ConversationHistory.OfType()); + + Assert.Null(request.OrchestrationId); + } + [Fact] public async Task OrchestrationIdSetDuringOrchestrationAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs new file mode 100644 index 0000000000..25d40a1c5a --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.IntegrationTests/TimeToLiveTests.cs @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using System.Reflection; +using Microsoft.Agents.AI.DurableTask.State; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.Extensions.Configuration; +using OpenAI.Chat; +using Xunit.Abstractions; + +namespace Microsoft.Agents.AI.DurableTask.IntegrationTests; + +/// +/// Tests for Time-To-Live (TTL) functionality of durable agent entities. +/// +[Collection("Sequential")] +[Trait("Category", "Integration")] +public sealed class TimeToLiveTests(ITestOutputHelper outputHelper) : IDisposable +{ + private static readonly TimeSpan s_defaultTimeout = Debugger.IsAttached + ? TimeSpan.FromMinutes(5) + : TimeSpan.FromSeconds(30); + + private static readonly IConfiguration s_configuration = + new ConfigurationBuilder() + .AddUserSecrets(Assembly.GetExecutingAssembly()) + .AddEnvironmentVariables() + .Build(); + + private readonly ITestOutputHelper _outputHelper = outputHelper; + private readonly CancellationTokenSource _cts = new(delay: s_defaultTimeout); + + private CancellationToken TestTimeoutToken => this._cts.Token; + + public void Dispose() => this._cts.Dispose(); + + [Fact] + public async Task EntityExpiresAfterTTLAsync() + { + // Arrange: Create agent with short TTL (10 seconds) + TimeSpan ttl = TimeSpan.FromSeconds(10); + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + name: "TTLTestAgent", + instructions: "You are a helpful assistant." + ); + + using TestHelper testHelper = TestHelper.Start( + this._outputHelper, + options => + { + options.DefaultTimeToLive = ttl; + options.MinimumTimeToLiveSignalDelay = TimeSpan.FromSeconds(1); + options.AddAIAgent(simpleAgent); + }); + + AIAgent agentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); + AgentThread thread = agentProxy.GetNewThread(); + DurableTaskClient client = testHelper.GetClient(); + AgentSessionId sessionId = thread.GetService(); + + // Act: Send a message to the agent + await agentProxy.RunAsync( + message: "Hello!", + thread, + cancellationToken: this.TestTimeoutToken); + + // Verify entity exists and get expiration time + EntityMetadata? entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + DurableAgentState state = entity.State.ReadAs(); + Assert.NotNull(state.Data.ExpirationTimeUtc); + DateTime expirationTime = state.Data.ExpirationTimeUtc.Value; + Assert.True(expirationTime > DateTime.UtcNow); + + // Calculate how long to wait: expiration time + buffer for signal processing + TimeSpan waitTime = expirationTime - DateTime.UtcNow + TimeSpan.FromSeconds(1); + if (waitTime > TimeSpan.Zero) + { + await Task.Delay(waitTime, this.TestTimeoutToken); + } + + // Poll the entity state until it's deleted (with timeout) + DateTime pollTimeout = DateTime.UtcNow.AddSeconds(10); + bool entityDeleted = false; + while (DateTime.UtcNow < pollTimeout && !entityDeleted) + { + entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + entityDeleted = entity is null; + + if (!entityDeleted) + { + await Task.Delay(TimeSpan.FromSeconds(1), this.TestTimeoutToken); + } + } + + // Assert: Verify entity state is deleted + Assert.True(entityDeleted, "Entity should have been deleted after TTL expiration"); + } + + [Fact] + public async Task EntityTTLResetsOnInteractionAsync() + { + // Arrange: Create agent with short TTL + TimeSpan ttl = TimeSpan.FromSeconds(6); + AIAgent simpleAgent = TestHelper.GetAzureOpenAIChatClient(s_configuration).CreateAIAgent( + name: "TTLResetTestAgent", + instructions: "You are a helpful assistant." + ); + + using TestHelper testHelper = TestHelper.Start( + this._outputHelper, + options => + { + options.DefaultTimeToLive = ttl; + options.MinimumTimeToLiveSignalDelay = TimeSpan.FromSeconds(1); + options.AddAIAgent(simpleAgent); + }); + + AIAgent agentProxy = simpleAgent.AsDurableAgentProxy(testHelper.Services); + AgentThread thread = agentProxy.GetNewThread(); + DurableTaskClient client = testHelper.GetClient(); + AgentSessionId sessionId = thread.GetService(); + + // Act: Send first message + await agentProxy.RunAsync( + message: "Hello!", + thread, + cancellationToken: this.TestTimeoutToken); + + EntityMetadata? entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + DurableAgentState state = entity.State.ReadAs(); + DateTime firstExpirationTime = state.Data.ExpirationTimeUtc!.Value; + + // Wait partway through TTL + await Task.Delay(TimeSpan.FromSeconds(3), this.TestTimeoutToken); + + // Send second message (should reset TTL) + await agentProxy.RunAsync( + message: "Hello again!", + thread, + cancellationToken: this.TestTimeoutToken); + + // Verify expiration time was updated + entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + state = entity.State.ReadAs(); + DateTime secondExpirationTime = state.Data.ExpirationTimeUtc!.Value; + Assert.True(secondExpirationTime > firstExpirationTime); + + // Calculate when the original expiration time would have been + DateTime originalExpirationTime = firstExpirationTime; + TimeSpan waitUntilOriginalExpiration = originalExpirationTime - DateTime.UtcNow + TimeSpan.FromSeconds(2); + + if (waitUntilOriginalExpiration > TimeSpan.Zero) + { + await Task.Delay(waitUntilOriginalExpiration, this.TestTimeoutToken); + } + + // Assert: Entity should still exist because TTL was reset + // The new expiration time should be in the future + entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + Assert.NotNull(entity); + Assert.True(entity.IncludesState); + + state = entity.State.ReadAs(); + Assert.NotNull(state); + Assert.NotNull(state.Data.ExpirationTimeUtc); + Assert.True( + state.Data.ExpirationTimeUtc > DateTime.UtcNow, + "Entity should still be valid because TTL was reset"); + + // Wait for the entity to be deleted + DateTime pollTimeout = DateTime.UtcNow.AddSeconds(10); + bool entityDeleted = false; + while (DateTime.UtcNow < pollTimeout && !entityDeleted) + { + entity = await client.Entities.GetEntityAsync(sessionId, true, this.TestTimeoutToken); + entityDeleted = entity is null; + + if (!entityDeleted) + { + await Task.Delay(TimeSpan.FromSeconds(1), this.TestTimeoutToken); + } + } + + // Assert: Entity should have been deleted + Assert.True(entityDeleted, "Entity should have been deleted after TTL expiration"); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs index 5bc4e8afad..69560421cf 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs @@ -276,15 +276,9 @@ public async ValueTask DisposeAsync() [SuppressMessage("Performance", "CA1812:Avoid uninstantiated internal classes", Justification = "Instantiated via dependency injection")] internal sealed class FakeChatClientAgent : AIAgent { - public FakeChatClientAgent() - { - this.Id = "fake-agent"; - this.Description = "A fake agent for testing"; - } - - public override string Id { get; } + protected override string? IdCore => "fake-agent"; - public override string? Description { get; } + public override string? Description => "A fake agent for testing"; public override AgentThread GetNewThread() { @@ -350,15 +344,9 @@ public FakeInMemoryAgentThread(JsonElement serializedThread, JsonSerializerOptio [SuppressMessage("Performance", "CA1812:Avoid uninstantiated internal classes", Justification = "Instantiated via dependency injection")] internal sealed class FakeMultiMessageAgent : AIAgent { - public FakeMultiMessageAgent() - { - this.Id = "fake-multi-message-agent"; - this.Description = "A fake agent that sends multiple messages for testing"; - } - - public override string Id { get; } + protected override string? IdCore => "fake-multi-message-agent"; - public override string? Description { get; } + public override string? Description => "A fake agent that sends multiple messages for testing"; public override AgentThread GetNewThread() { diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs index 78a3048747..3e80a58369 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs @@ -421,7 +421,7 @@ private static List ParseSseEvents(string responseContent) private sealed class MultiResponseAgent : AIAgent { - public override string Id => "multi-response-agent"; + protected override string? IdCore => "multi-response-agent"; public override string? Description => "Agent that produces multiple text chunks"; @@ -510,7 +510,7 @@ public TestInMemoryAgentThread(JsonElement serializedThreadState, JsonSerializer private sealed class TestAgent : AIAgent { - public override string Id => "test-agent"; + protected override string? IdCore => "test-agent"; public override string? Description => "Test agent"; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.OpenAI.UnitTests/OpenAIResponsesIntegrationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.OpenAI.UnitTests/OpenAIResponsesIntegrationTests.cs index abf66a732f..2dd5b85e5f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.OpenAI.UnitTests/OpenAIResponsesIntegrationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.OpenAI.UnitTests/OpenAIResponsesIntegrationTests.cs @@ -49,7 +49,7 @@ public async Task CreateResponseStreaming_WithSimpleMessage_ReturnsStreamingUpda const string ExpectedResponse = "One Two Three"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Count to 3"); @@ -90,10 +90,10 @@ public async Task CreateResponse_WithSimpleMessage_ReturnsCompleteResponseAsync( const string ExpectedResponse = "Hello! How can I help you today?"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Hello"); + ResponseResult response = await responseClient.CreateResponseAsync("Hello"); // Assert Assert.NotNull(response); @@ -117,7 +117,7 @@ public async Task CreateResponseStreaming_WithMultipleChunks_StreamsAllContentAs const string ExpectedResponse = "This is a test response with multiple words"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -162,12 +162,12 @@ public async Task CreateResponse_WithMultipleAgents_EachAgentRespondsCorrectlyAs (Agent1Name, Agent1Instructions, Agent1Response), (Agent2Name, Agent2Instructions, Agent2Response)); - OpenAIResponseClient responseClient1 = this.CreateResponseClient(Agent1Name); - OpenAIResponseClient responseClient2 = this.CreateResponseClient(Agent2Name); + ResponsesClient responseClient1 = this.CreateResponseClient(Agent1Name); + ResponsesClient responseClient2 = this.CreateResponseClient(Agent2Name); // Act - OpenAIResponse response1 = await responseClient1.CreateResponseAsync("Hello"); - OpenAIResponse response2 = await responseClient2.CreateResponseAsync("Hello"); + ResponseResult response1 = await responseClient1.CreateResponseAsync("Hello"); + ResponseResult response2 = await responseClient2.CreateResponseAsync("Hello"); // Assert string content1 = response1.GetOutputText(); @@ -190,10 +190,10 @@ public async Task CreateResponse_SameAgentStreamingAndNonStreaming_BothWorkCorre const string ExpectedResponse = "This is the response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - Non-streaming - OpenAIResponse nonStreamingResponse = await responseClient.CreateResponseAsync("Test"); + ResponseResult nonStreamingResponse = await responseClient.CreateResponseAsync("Test"); // Act - Streaming AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -224,10 +224,10 @@ public async Task CreateResponse_CompletedResponse_HasCorrectStatusAsync() const string ExpectedResponse = "Complete"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert Assert.Equal(ResponseStatus.Completed, response.Status); @@ -247,7 +247,7 @@ public async Task CreateResponseStreaming_VerifyEventSequence_ContainsExpectedEv const string ExpectedResponse = "Test response with multiple words"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -286,7 +286,7 @@ public async Task CreateResponseStreaming_EmptyResponse_HandlesGracefullyAsync() const string ExpectedResponse = ""; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -316,10 +316,10 @@ public async Task CreateResponse_IncludesMetadata_HasRequiredFieldsAsync() const string ExpectedResponse = "Response with metadata"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert Assert.NotNull(response.Id); @@ -340,7 +340,7 @@ public async Task CreateResponseStreaming_LongText_StreamsAllContentAsync() string expectedResponse = string.Join(" ", Enumerable.Range(1, 100).Select(i => $"Word{i}")); this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, expectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Generate long text"); @@ -371,7 +371,7 @@ public async Task CreateResponseStreaming_OutputIndices_AreConsistentAsync() const string ExpectedResponse = "Test output index"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -407,7 +407,7 @@ public async Task CreateResponseStreaming_SingleWord_StreamsCorrectlyAsync() const string ExpectedResponse = "Hello"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -437,7 +437,7 @@ public async Task CreateResponseStreaming_SpecialCharacters_PreservesFormattingA const string ExpectedResponse = "Hello! How are you? I'm fine. 100% great!"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -467,10 +467,10 @@ public async Task CreateResponse_SpecialCharacters_PreservesContentAsync() const string ExpectedResponse = "Symbols: @#$%^&*() Quotes: \"Hello\" 'World' Unicode: 你好 🌍"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert string content = response.GetOutputText(); @@ -489,7 +489,7 @@ public async Task CreateResponseStreaming_ItemIds_AreConsistentAsync() const string ExpectedResponse = "Testing item IDs"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -525,12 +525,12 @@ public async Task CreateResponse_MultipleSequentialRequests_AllSucceedAsync() const string ExpectedResponse = "Response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act & Assert - Make 5 sequential requests for (int i = 0; i < 5; i++) { - OpenAIResponse response = await responseClient.CreateResponseAsync($"Request {i}"); + ResponseResult response = await responseClient.CreateResponseAsync($"Request {i}"); Assert.NotNull(response); Assert.Equal(ResponseStatus.Completed, response.Status); Assert.Equal(ExpectedResponse, response.GetOutputText()); @@ -549,7 +549,7 @@ public async Task CreateResponseStreaming_MultipleSequentialRequests_AllStreamCo const string ExpectedResponse = "Streaming response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act & Assert - Make 3 sequential streaming requests for (int i = 0; i < 3; i++) @@ -581,13 +581,13 @@ public async Task CreateResponse_MultipleRequests_GenerateUniqueIdsAsync() const string ExpectedResponse = "Response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act List responseIds = []; for (int i = 0; i < 10; i++) { - OpenAIResponse response = await responseClient.CreateResponseAsync($"Request {i}"); + ResponseResult response = await responseClient.CreateResponseAsync($"Request {i}"); responseIds.Add(response.Id); } @@ -608,7 +608,7 @@ public async Task CreateResponseStreaming_SequenceNumbers_AreMonotonicallyIncrea const string ExpectedResponse = "Test sequence numbers with multiple words"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -641,10 +641,10 @@ public async Task CreateResponse_ModelInformation_IsCorrectAsync() const string ExpectedResponse = "Test model info"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert Assert.NotNull(response.Model); @@ -663,7 +663,7 @@ public async Task CreateResponseStreaming_Punctuation_PreservesContentAsync() const string ExpectedResponse = "Hello, world! How are you today? I'm doing well."; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -693,10 +693,10 @@ public async Task CreateResponse_ShortInput_ReturnsValidResponseAsync() const string ExpectedResponse = "OK"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Hi"); + ResponseResult response = await responseClient.CreateResponseAsync("Hi"); // Assert Assert.NotNull(response); @@ -716,7 +716,7 @@ public async Task CreateResponseStreaming_ContentIndices_AreConsistentAsync() const string ExpectedResponse = "Test content indices"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -748,10 +748,10 @@ public async Task CreateResponse_Newlines_PreservesFormattingAsync() const string ExpectedResponse = "Line 1\nLine 2\nLine 3"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Test"); + ResponseResult response = await responseClient.CreateResponseAsync("Test"); // Assert string content = response.GetOutputText(); @@ -771,7 +771,7 @@ public async Task CreateResponseStreaming_Newlines_PreservesFormattingAsync() const string ExpectedResponse = "First line\nSecond line\nThird line"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -807,10 +807,10 @@ public async Task CreateResponse_ImageContent_ReturnsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.ImageContentMockChatClient(ImageUrl)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Show me an image"); + ResponseResult response = await responseClient.CreateResponseAsync("Show me an image"); // Assert Assert.NotNull(response); @@ -834,7 +834,7 @@ public async Task CreateResponseStreaming_ImageContent_StreamsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.ImageContentMockChatClient(ImageUrl)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Show me an image"); @@ -868,10 +868,10 @@ public async Task CreateResponse_AudioContent_ReturnsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.AudioContentMockChatClient(AudioData, Transcript)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Generate audio"); + ResponseResult response = await responseClient.CreateResponseAsync("Generate audio"); // Assert Assert.NotNull(response); @@ -896,7 +896,7 @@ public async Task CreateResponseStreaming_AudioContent_StreamsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.AudioContentMockChatClient(AudioData, Transcript)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Generate audio"); @@ -930,10 +930,10 @@ public async Task CreateResponse_FunctionCall_ReturnsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.FunctionCallMockChatClient(FunctionName, Arguments)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("What's the weather?"); + ResponseResult response = await responseClient.CreateResponseAsync("What's the weather?"); // Assert Assert.NotNull(response); @@ -957,7 +957,7 @@ public async Task CreateResponseStreaming_FunctionCall_StreamsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.FunctionCallMockChatClient(FunctionName, Arguments)); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Calculate 2+2"); @@ -988,10 +988,10 @@ public async Task CreateResponse_MixedContent_ReturnsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.MixedContentMockChatClient()); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act - OpenAIResponse response = await responseClient.CreateResponseAsync("Show me various content"); + ResponseResult response = await responseClient.CreateResponseAsync("Show me various content"); // Assert Assert.NotNull(response); @@ -1014,7 +1014,7 @@ public async Task CreateResponseStreaming_MixedContent_StreamsCorrectlyAsync() instructions: Instructions, chatClient: new TestHelpers.MixedContentMockChatClient()); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Show me various content"); @@ -1047,7 +1047,7 @@ public async Task CreateResponseStreaming_TextDone_IncludesDoneEventAsync() const string ExpectedResponse = "Complete text response"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -1075,7 +1075,7 @@ public async Task CreateResponseStreaming_ContentPartAdded_IncludesEventAsync() const string ExpectedResponse = "Response with content parts"; this._httpClient = await this.CreateTestServerAsync(AgentName, Instructions, ExpectedResponse); - OpenAIResponseClient responseClient = this.CreateResponseClient(AgentName); + ResponsesClient responseClient = this.CreateResponseClient(AgentName); // Act AsyncCollectionResult streamingResult = responseClient.CreateResponseStreamingAsync("Test"); @@ -1122,7 +1122,7 @@ public async Task CreateResponse_WithConversationId_DoesNotForwardConversationId string conversationId = convDoc.RootElement.GetProperty("id").GetString()!; // Act - Send request with conversation ID using raw HTTP - // (OpenAI SDK doesn't expose ConversationId directly on ResponseCreationOptions) + // (OpenAI SDK doesn't expose ConversationId directly on CreateResponseOptions) var requestBody = new { input = "Test", @@ -1201,9 +1201,9 @@ public async Task CreateResponseStreaming_WithConversationId_DoesNotForwardConve Assert.Null(mockChatClient.LastChatOptions.ConversationId); } - private OpenAIResponseClient CreateResponseClient(string agentName) + private ResponsesClient CreateResponseClient(string agentName) { - return new OpenAIResponseClient( + return new ResponsesClient( model: "test-model", credential: new ApiKeyCredential("test-api-key"), options: new OpenAIClientOptions diff --git a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs index 781ccb123e..127fe1a58f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.OpenAI.UnitTests/Extensions/OpenAIResponseClientExtensionsTests.cs @@ -55,9 +55,9 @@ public async IAsyncEnumerable GetStreamingResponseAsync( } /// - /// Creates a test OpenAIResponseClient implementation for testing. + /// Creates a test ResponsesClient implementation for testing. /// - private sealed class TestOpenAIResponseClient : OpenAIResponseClient + private sealed class TestOpenAIResponseClient : ResponsesClient { public TestOpenAIResponseClient() { @@ -147,7 +147,7 @@ public void CreateAIAgent_WithNullClient_ThrowsArgumentNullException() { // Act & Assert var exception = Assert.Throws(() => - ((OpenAIResponseClient)null!).CreateAIAgent()); + ((ResponsesClient)null!).CreateAIAgent()); Assert.Equal("client", exception.ParamName); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs index 32adb93ddb..d5976a3174 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.IntegrationTests/MediaInputTest.cs @@ -24,7 +24,7 @@ public sealed class MediaInputTest(ITestOutputHelper output) : IntegrationTest(o private const string ImageReference = "https://sample-files.com/downloads/images/jpg/web_optimized_1200x800_97kb.jpg"; [Theory] - [InlineData(ImageReference, "image/jpeg")] + [InlineData(ImageReference, "image/jpeg", Skip = "Failing consistently in the agent service api")] [InlineData(PdfReference, "application/pdf", Skip = "Not currently supported by agent service api")] public async Task ValidateFileUrlAsync(string fileSource, string mediaType) { diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeMapSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeMapSmokeTests.cs index 7f463c0373..5ea4715680 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeMapSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeMapSmokeTests.cs @@ -21,7 +21,7 @@ public async Task Test_EdgeMap_MaintainsFanInEdgeStateAsync() Dictionary> workflowEdges = []; - FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0)); + FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0), null); Edge fanInEdge = new(edgeData); workflowEdges["executor1"] = [fanInEdge]; diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs index 4239b6ef6d..99cd46dd4b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs @@ -155,7 +155,7 @@ public async Task Test_FanInEdgeRunnerAsync() runContext.Executors["executor2"] = new ForwardMessageExecutor("executor2"); runContext.Executors["executor3"] = new ForwardMessageExecutor("executor3"); - FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0)); + FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0), null); FanInEdgeRunner runner = new(runContext, edgeData); // Step 1: Send message from executor1, should not forward yet. diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/JsonSerializationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/JsonSerializationTests.cs index f55185a78f..686cdea308 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/JsonSerializationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/JsonSerializationTests.cs @@ -118,7 +118,7 @@ public void Test_FanOutEdgeInfo_JsonRoundtrip() RunJsonRoundtrip(TestFanOutEdgeInfo_Assigner, predicate: TestFanOutEdgeInfo_Assigner.CreateValidator()); } - private static FanInEdgeData TestFanInEdgeData => new(["SourceExecutor1", "SourceExecutor2"], "TargetExecutor", TakeEdgeId()); + private static FanInEdgeData TestFanInEdgeData => new(["SourceExecutor1", "SourceExecutor2"], "TargetExecutor", TakeEdgeId(), null); private static FanInEdgeInfo TestFanInEdgeInfo => new(TestFanInEdgeData); [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs index 9cf460e658..1878a55868 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs @@ -137,17 +137,17 @@ public void Test_EdgeInfos() RunEdgeInfoMatchTest(fanOutEdgeWithAssigner); // FanIn Edges - Edge fanInEdge = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1), TakeEdgeId())); + Edge fanInEdge = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1), TakeEdgeId(), null)); RunEdgeInfoMatchTest(fanInEdge); - Edge fanInEdge2 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1), TakeEdgeId())); + Edge fanInEdge2 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1), TakeEdgeId(), null)); RunEdgeInfoMatchTest(fanInEdge, fanInEdge2); - Edge fanInEdge3 = new(new FanInEdgeData([Source(2), Source(3), Source(1)], Sink(1), TakeEdgeId())); + Edge fanInEdge3 = new(new FanInEdgeData([Source(2), Source(3), Source(1)], Sink(1), TakeEdgeId(), null)); RunEdgeInfoMatchTest(fanInEdge, fanInEdge3, expect: false); // Order matters (though for FanIn maybe it shouldn't?) - Edge fanInEdge4 = new(new FanInEdgeData([Source(1), Source(2), Source(4)], Sink(1), TakeEdgeId())); - Edge fanInEdge5 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(2), TakeEdgeId())); + Edge fanInEdge4 = new(new FanInEdgeData([Source(1), Source(2), Source(4)], Sink(1), TakeEdgeId(), null)); + Edge fanInEdge5 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(2), TakeEdgeId(), null)); RunEdgeInfoMatchTest(fanInEdge, fanInEdge4, expect: false); // Identity matters RunEdgeInfoMatchTest(fanInEdge, fanInEdge5, expect: false); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs index a0e57006ed..16a51876d0 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs @@ -57,7 +57,7 @@ internal sealed class HelloAgent(string id = nameof(HelloAgent)) : AIAgent public const string Greeting = "Hello World!"; public const string DefaultId = nameof(HelloAgent); - public override string Id => id; + protected override string? IdCore => id; public override string? Name => id; public override AgentThread GetNewThread() diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs index b93d7862d5..daff2c248e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs @@ -19,7 +19,7 @@ public class SpecializedExecutorSmokeTests { public class TestAIAgent(List? messages = null, string? id = null, string? name = null) : AIAgent { - public override string Id => id ?? base.Id; + protected override string? IdCore => id; public override string? Name => name; public static List ToChatMessages(params string[] messages) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs index 4966585211..9ddc94cf71 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs @@ -13,7 +13,7 @@ namespace Microsoft.Agents.AI.Workflows.UnitTests; internal class TestEchoAgent(string? id = null, string? name = null, string? prefix = null) : AIAgent { - public override string Id => id ?? base.Id; + protected override string? IdCore => id; public override string? Name => name ?? base.Name; public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowVisualizerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowVisualizerTests.cs index 4e7aa51ea0..447c52a66e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowVisualizerTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowVisualizerTests.cs @@ -394,4 +394,61 @@ public void Test_WorkflowViz_Mermaid_Mixed_EdgeTypes() // Check fan-in (should have intermediate node) mermaidContent.Should().Contain("((fan-in))"); } + + [Fact] + public void Test_WorkflowViz_Mermaid_Edge_Label_With_Pipe() + { + // Test that pipe characters in labels are properly escaped + var start = new MockExecutor("start"); + var end = new MockExecutor("end"); + + var workflow = new WorkflowBuilder("start") + .AddEdge(start, end, label: "High | Low Priority") + .Build(); + + var mermaidContent = workflow.ToMermaidString(); + + // Should escape pipe character + mermaidContent.Should().Contain("start -->|High | Low Priority| end"); + // Should not contain unescaped pipe that would break syntax + mermaidContent.Should().NotContain("-->|High | Low"); + } + + [Fact] + public void Test_WorkflowViz_Mermaid_Edge_Label_With_Special_Chars() + { + // Test that special characters are properly escaped + var start = new MockExecutor("start"); + var end = new MockExecutor("end"); + + var workflow = new WorkflowBuilder("start") + .AddEdge(start, end, label: "Score >= 90 & < 100") + .Build(); + + var mermaidContent = workflow.ToMermaidString(); + + // Should escape special characters + mermaidContent.Should().Contain("&"); + mermaidContent.Should().Contain(">"); + mermaidContent.Should().Contain("<"); + } + + [Fact] + public void Test_WorkflowViz_Mermaid_Edge_Label_With_Newline() + { + // Test that newlines are converted to
+ var start = new MockExecutor("start"); + var end = new MockExecutor("end"); + + var workflow = new WorkflowBuilder("start") + .AddEdge(start, end, label: "Line 1\nLine 2") + .Build(); + + var mermaidContent = workflow.ToMermaidString(); + + // Should convert newline to
+ mermaidContent.Should().Contain("Line 1
Line 2"); + // Should not contain literal newline in the label (but the overall output has newlines between statements) + mermaidContent.Should().NotContain("Line 1\nLine 2"); + } } diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunStreamingTests.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunStreamingTests.cs index 669a4dd2a0..80a148d7fc 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunStreamingTests.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunStreamingTests.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; using AgentConformance.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseStoreTrueChatClientAgentRunStreamingTests() : ChatClientAgentRunStreamingTests(() => new(store: true)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() => @@ -16,7 +16,7 @@ public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() public class OpenAIResponseStoreFalseChatClientAgentRunStreamingTests() : ChatClientAgentRunStreamingTests(() => new(store: false)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() => diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunTests.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunTests.cs index af2f1c14ec..8b742e2964 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunTests.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseChatClientAgentRunTests.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; using AgentConformance.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseStoreTrueChatClientAgentRunTests() : ChatClientAgentRunTests(() => new(store: true)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() => @@ -16,7 +16,7 @@ public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() public class OpenAIResponseStoreFalseChatClientAgentRunTests() : ChatClientAgentRunTests(() => new(store: false)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithInstructionsAndNoMessageReturnsExpectedResultAsync() => diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index a58583fbca..c6c84db569 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -12,13 +12,13 @@ using OpenAI.Responses; using Shared.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture { private static readonly OpenAIConfiguration s_config = TestConfiguration.LoadSection(); - private OpenAIResponseClient _openAIResponseClient = null!; + private ResponsesClient _openAIResponseClient = null!; private ChatClientAgent _agent = null!; public AIAgent Agent => this._agent; @@ -77,7 +77,7 @@ public async Task CreateChatClientAgentAsync( { Instructions = instructions, Tools = aiTools, - RawRepresentationFactory = new Func(_ => new ResponseCreationOptions() { StoredOutputEnabled = store }) + RawRepresentationFactory = new Func(_ => new CreateResponseOptions() { StoredOutputEnabled = store }) }, }); @@ -92,7 +92,7 @@ public Task DeleteThreadAsync(AgentThread thread) => public async Task InitializeAsync() { this._openAIResponseClient = new OpenAIClient(s_config.ApiKey) - .GetOpenAIResponseClient(s_config.ChatModelId); + .GetResponsesClient(s_config.ChatModelId); this._agent = await this.CreateChatClientAgentAsync(); } diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunStreamingTests.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunStreamingTests.cs index e2e7e28bbd..c12f8f2db5 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunStreamingTests.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunStreamingTests.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; using AgentConformance.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseStoreTrueRunStreamingTests() : RunStreamingTests(() => new(store: true)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithNoMessageDoesNotFailAsync() => Task.CompletedTask; @@ -15,7 +15,7 @@ public override Task RunWithNoMessageDoesNotFailAsync() => public class OpenAIResponseStoreFalseRunStreamingTests() : RunStreamingTests(() => new(store: false)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithNoMessageDoesNotFailAsync() => diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunTests.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunTests.cs index 41c5254474..423ac583c7 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunTests.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseRunTests.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; using AgentConformance.IntegrationTests; -namespace OpenAIResponse.IntegrationTests; +namespace ResponseResult.IntegrationTests; public class OpenAIResponseStoreTrueRunTests() : RunTests(() => new(store: true)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithNoMessageDoesNotFailAsync() => Task.CompletedTask; @@ -15,7 +15,7 @@ public override Task RunWithNoMessageDoesNotFailAsync() => public class OpenAIResponseStoreFalseRunTests() : RunTests(() => new(store: false)) { - private const string SkipReason = "OpenAIResponse does not support empty messages"; + private const string SkipReason = "ResponseResult does not support empty messages"; [Fact(Skip = SkipReason)] public override Task RunWithNoMessageDoesNotFailAsync() => diff --git a/python/.env.example b/python/.env.example index f864f18f72..c09300d775 100644 --- a/python/.env.example +++ b/python/.env.example @@ -33,7 +33,6 @@ ANTHROPIC_MODEL="" OLLAMA_ENDPOINT="" OLLAMA_MODEL="" # Observability -ENABLE_OTEL=true +ENABLE_INSTRUMENTATION=true ENABLE_SENSITIVE_DATA=true -OTLP_ENDPOINT="http://localhost:4317/" -# APPLICATIONINSIGHTS_CONNECTION_STRING="..." +OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317/" diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index 54171d870a..7b012ccf23 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -7,6 +7,51 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- **agent-framework-azurefunctions**: Durable Agents: platforms should use consistent entity method names (#2234) + +## [1.0.0b251216] - 2025-12-16 + +### Added + +- **agent-framework-ollama**: Ollama connector for Agent Framework (#1104) +- **agent-framework-core**: Added custom args and thread object to `ai_function` kwargs (#2769) +- **agent-framework-core**: Enable checkpointing for `WorkflowAgent` (#2774) + +### Changed + +- **agent-framework-core**: [BREAKING] Observability updates (#2782) +- **agent-framework-core**: Use agent description in `HandoffBuilder` auto-generated tools (#2714) +- **agent-framework-core**: Remove warnings from workflow builder when not using factories (#2808) + +### Fixed + +- **agent-framework-core**: Fix `WorkflowAgent` to include thread conversation history (#2774) +- **agent-framework-core**: Fix context duplication in handoff workflows when restoring from checkpoint (#2867) +- **agent-framework-core**: Fix middleware terminate flag to exit function calling loop immediately (#2868) +- **agent-framework-core**: Fix `WorkflowAgent` to emit `yield_output` as agent response (#2866) +- **agent-framework-core**: Filter framework kwargs from MCP tool invocations (#2870) + +## [1.0.0b251211] - 2025-12-11 + +### Added + +- **agent-framework-core**: Extend HITL support for all orchestration patterns (#2620) +- **agent-framework-core**: Add factory pattern to concurrent orchestration builder (#2738) +- **agent-framework-core**: Add factory pattern to sequential orchestration builder (#2710) +- **agent-framework-azure-ai**: Capture file IDs from code interpreter in streaming responses (#2741) + +### Changed + +- **agent-framework-azurefunctions**: Change DurableAIAgent log level from warning to debug when invoked without thread (#2736) + +### Fixed + +- **agent-framework-core**: Added more complete parsing for mcp tool arguments (#2756) +- **agent-framework-core**: Fix GroupChat ManagerSelectionResponse JSON Schema for OpenAI Structured Outputs (#2750) +- **samples**: Standardize OpenAI API key environment variable naming (#2629) + ## [1.0.0b251209] - 2025-12-09 ### Added @@ -347,7 +392,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 For more information, see the [announcement blog post](https://devblogs.microsoft.com/foundry/introducing-microsoft-agent-framework-the-open-source-engine-for-agentic-ai-apps/). -[Unreleased]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251209...HEAD +[Unreleased]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251216...HEAD +[1.0.0b251216]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251211...python-1.0.0b251216 +[1.0.0b251211]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251209...python-1.0.0b251211 [1.0.0b251209]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251204...python-1.0.0b251209 [1.0.0b251204]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251120...python-1.0.0b251204 [1.0.0b251120]: https://github.com/microsoft/agent-framework/compare/python-1.0.0b251117...python-1.0.0b251120 diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 4f86eb5afc..f90d7214a0 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -5,7 +5,7 @@ import re import uuid from collections.abc import AsyncIterable, Sequence -from typing import Any, cast +from typing import Any, Final, cast import httpx from a2a.client import Client, ClientConfig, ClientFactory, minimal_agent_card @@ -38,6 +38,7 @@ UriContent, prepend_agent_framework_to_user_agent, ) +from agent_framework.observability import use_agent_instrumentation __all__ = ["A2AAgent"] @@ -58,6 +59,7 @@ def _get_uri_data(uri: str) -> str: return match.group("base64_data") +@use_agent_instrumentation class A2AAgent(BaseAgent): """Agent2Agent (A2A) protocol implementation. @@ -69,6 +71,8 @@ class A2AAgent(BaseAgent): Can be initialized with a URL, AgentCard, or existing A2A Client instance. """ + AGENT_PROVIDER_NAME: Final[str] = "A2A" + def __init__( self, *, diff --git a/python/packages/a2a/pyproject.toml b/python/packages/a2a/pyproject.toml index 0233d66628..56d79ce7fe 100644 --- a/python/packages/a2a/pyproject.toml +++ b/python/packages/a2a/pyproject.toml @@ -4,7 +4,7 @@ description = "A2A integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index ab7eb53940..db2f160a9d 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -23,7 +23,7 @@ from agent_framework._middleware import use_chat_middleware from agent_framework._tools import use_function_invocation from agent_framework._types import BaseContent, Contents -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from ._event_converters import AGUIEventConverter from ._http_service import AGUIHttpService @@ -89,7 +89,7 @@ async def response_wrapper(self, *args: Any, **kwargs: Any) -> ChatResponse: @_apply_server_function_call_unwrap @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AGUIChatClient(BaseChatClient): """Chat client for communicating with AG-UI compliant servers. diff --git a/python/packages/ag-ui/pyproject.toml b/python/packages/ag-ui/pyproject.toml index 0a501b16be..8a4adceeee 100644 --- a/python/packages/ag-ui/pyproject.toml +++ b/python/packages/ag-ui/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agent-framework-ag-ui" -version = "1.0.0b251209" +version = "1.0.0b251216" description = "AG-UI protocol integration for Agent Framework" readme = "README.md" license-files = ["LICENSE"] diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 96a70bc4a0..e4eca2d005 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -35,7 +35,7 @@ ) from agent_framework._pydantic import AFBaseSettings from agent_framework.exceptions import ServiceInitializationError -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from anthropic import AsyncAnthropic from anthropic.types.beta import ( BetaContentBlock, @@ -110,7 +110,7 @@ class AnthropicSettings(AFBaseSettings): @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AnthropicClient(BaseChatClient): """Anthropic Chat client.""" diff --git a/python/packages/anthropic/pyproject.toml b/python/packages/anthropic/pyproject.toml index 272a612c0a..55ef501a9f 100644 --- a/python/packages/anthropic/pyproject.toml +++ b/python/packages/anthropic/pyproject.toml @@ -4,7 +4,7 @@ description = "Anthropic integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azure-ai-search/pyproject.toml b/python/packages/azure-ai-search/pyproject.toml index cda63bf6ca..d4227e3dd8 100644 --- a/python/packages/azure-ai-search/pyproject.toml +++ b/python/packages/azure-ai-search/pyproject.toml @@ -4,7 +4,7 @@ description = "Azure AI Search integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index b4c55f0e55..839687fbaf 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -43,7 +43,7 @@ use_function_invocation, ) from agent_framework.exceptions import ServiceInitializationError, ServiceResponseException -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from azure.ai.agents.aio import AgentsClient from azure.ai.agents.models import ( Agent, @@ -63,6 +63,8 @@ McpTool, MessageDeltaChunk, MessageDeltaTextContent, + MessageDeltaTextFileCitationAnnotation, + MessageDeltaTextFilePathAnnotation, MessageDeltaTextUrlCitationAnnotation, MessageImageUrlParam, MessageInputContentBlock, @@ -105,7 +107,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AzureAIAgentClient(BaseChatClient): """Azure AI Agent Chat client.""" @@ -471,6 +473,45 @@ def _extract_url_citations( return url_citations + def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) -> list[HostedFileContent]: + """Extract file references from MessageDeltaChunk annotations. + + Code interpreter generates files that are referenced via file path or file citation + annotations in the message content. This method extracts those file IDs and returns + them as HostedFileContent objects. + + Handles two annotation types: + - MessageDeltaTextFilePathAnnotation: Contains file_path.file_id + - MessageDeltaTextFileCitationAnnotation: Contains file_citation.file_id + + Args: + message_delta_chunk: The message delta chunk to process + + Returns: + List of HostedFileContent objects for any files referenced in annotations + """ + file_contents: list[HostedFileContent] = [] + + for content in message_delta_chunk.delta.content: + if isinstance(content, MessageDeltaTextContent) and content.text and content.text.annotations: + for annotation in content.text.annotations: + if isinstance(annotation, MessageDeltaTextFilePathAnnotation): + # Extract file_id from the file_path annotation + file_path = getattr(annotation, "file_path", None) + if file_path is not None: + file_id = getattr(file_path, "file_id", None) + if file_id: + file_contents.append(HostedFileContent(file_id=file_id)) + elif isinstance(annotation, MessageDeltaTextFileCitationAnnotation): + # Extract file_id from the file_citation annotation + file_citation = getattr(annotation, "file_citation", None) + if file_citation is not None: + file_id = getattr(file_citation, "file_id", None) + if file_id: + file_contents.append(HostedFileContent(file_id=file_id)) + + return file_contents + def _get_real_url_from_citation_reference( self, citation_url: str, azure_search_tool_calls: list[dict[str, Any]] ) -> str: @@ -530,6 +571,9 @@ async def _process_stream( # Extract URL citations from the delta chunk url_citations = self._extract_url_citations(event_data, azure_search_tool_calls) + # Extract file path contents from code interpreter outputs + file_contents = self._extract_file_path_contents(event_data) + # Create contents with citations if any exist citation_content: list[Contents] = [] if event_data.text or url_citations: @@ -538,6 +582,9 @@ async def _process_stream( text_content_obj.annotations = url_citations citation_content.append(text_content_obj) + # Add file contents from file path annotations + citation_content.extend(file_contents) + yield ChatResponseUpdate( role=role, contents=citation_content if citation_content else None, diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 63bd2b27df..f4d5328f03 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -15,7 +15,7 @@ use_function_invocation, ) from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from agent_framework.openai._responses_client import OpenAIBaseResponsesClient from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import ( @@ -49,7 +49,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AzureAIClient(OpenAIBaseResponsesClient): """Azure AI Agent client.""" @@ -164,27 +164,94 @@ def __init__( # Track whether we should close client connection self._should_close_client = should_close_client - async def setup_azure_ai_observability(self, enable_sensitive_data: bool | None = None) -> None: - """Use this method to setup tracing in your Azure AI Project. + async def configure_azure_monitor( + self, + enable_sensitive_data: bool = False, + **kwargs: Any, + ) -> None: + """Setup observability with Azure Monitor (Azure AI Foundry integration). + + This method configures Azure Monitor for telemetry collection using the + connection string from the Azure AI project client. - This will take the connection string from the project project_client. - It will override any connection string that is set in the environment variables. - It will disable any OTLP endpoint that might have been set. + Args: + enable_sensitive_data: Enable sensitive data logging (prompts, responses). + Should only be enabled in development/test environments. Default is False. + **kwargs: Additional arguments passed to configure_azure_monitor(). + Common options include: + - enable_live_metrics (bool): Enable Azure Monitor Live Metrics + - credential (TokenCredential): Azure credential for Entra ID auth + - resource (Resource): Custom OpenTelemetry resource + See https://learn.microsoft.com/python/api/azure-monitor-opentelemetry/azure.monitor.opentelemetry.configure_azure_monitor + for full list of options. + + Raises: + ImportError: If azure-monitor-opentelemetry-exporter is not installed. + + Examples: + .. code-block:: python + + from agent_framework.azure import AzureAIClient + from azure.ai.projects.aio import AIProjectClient + from azure.identity.aio import DefaultAzureCredential + + async with ( + DefaultAzureCredential() as credential, + AIProjectClient( + endpoint="https://your-project.api.azureml.ms", credential=credential + ) as project_client, + AzureAIClient(project_client=project_client) as client, + ): + # Setup observability with defaults + await client.configure_azure_monitor() + + # With live metrics enabled + await client.configure_azure_monitor(enable_live_metrics=True) + + # With sensitive data logging (dev/test only) + await client.configure_azure_monitor(enable_sensitive_data=True) + + Note: + This method retrieves the Application Insights connection string from the + Azure AI project client automatically. You must have Application Insights + configured in your Azure AI project for this to work. """ + # Get connection string from project client try: conn_string = await self.project_client.telemetry.get_application_insights_connection_string() except ResourceNotFoundError: logger.warning( - "No Application Insights connection string found for the Azure AI Project, " - "please call setup_observability() manually." + "No Application Insights connection string found for the Azure AI Project. " + "Please ensure Application Insights is configured in your Azure AI project, " + "or call configure_otel_providers() manually with custom exporters." ) return - from agent_framework.observability import setup_observability - setup_observability( - applicationinsights_connection_string=conn_string, enable_sensitive_data=enable_sensitive_data + # Import Azure Monitor with proper error handling + try: + from azure.monitor.opentelemetry import configure_azure_monitor + except ImportError as exc: + raise ImportError( + "azure-monitor-opentelemetry is required for Azure Monitor integration. " + "Install it with: pip install azure-monitor-opentelemetry" + ) from exc + + from agent_framework.observability import create_metric_views, create_resource, enable_instrumentation + + # Create resource if not provided in kwargs + if "resource" not in kwargs: + kwargs["resource"] = create_resource() + + # Configure Azure Monitor with connection string and kwargs + configure_azure_monitor( + connection_string=conn_string, + views=create_metric_views(), + **kwargs, ) + # Complete setup with core observability + enable_instrumentation(enable_sensitive_data=enable_sensitive_data) + async def __aenter__(self) -> "Self": """Async context manager entry.""" return self @@ -268,6 +335,10 @@ async def _get_agent_reference_or_create( if "tools" in run_options: args["tools"] = run_options["tools"] + if "temperature" in run_options: + args["temperature"] = run_options["temperature"] + if "top_p" in run_options: + args["top_p"] = run_options["top_p"] if "response_format" in run_options: response_format = run_options["response_format"] @@ -346,7 +417,7 @@ async def prepare_options( # Remove properties that are not supported on request level # but were configured on agent level - exclude = ["model", "tools", "response_format"] + exclude = ["model", "tools", "response_format", "temperature", "top_p"] for property in exclude: run_options.pop(property, None) diff --git a/python/packages/azure-ai/pyproject.toml b/python/packages/azure-ai/pyproject.toml index 4d85dffdab..685172e2e4 100644 --- a/python/packages/azure-ai/pyproject.toml +++ b/python/packages/azure-ai/pyproject.toml @@ -4,7 +4,7 @@ description = "Azure AI Foundry integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py index 80b2cebdb3..f1b4dafb63 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py @@ -24,6 +24,7 @@ FunctionCallContent, FunctionResultContent, HostedCodeInterpreterTool, + HostedFileContent, HostedFileSearchTool, HostedMCPTool, HostedVectorStoreContent, @@ -42,6 +43,8 @@ FileInfo, MessageDeltaChunk, MessageDeltaTextContent, + MessageDeltaTextFileCitationAnnotation, + MessageDeltaTextFilePathAnnotation, MessageDeltaTextUrlCitationAnnotation, RequiredFunctionToolCall, RequiredMcpToolCall, @@ -1362,6 +1365,108 @@ def test_azure_ai_chat_client_extract_url_citations_with_citations(mock_agents_c assert citation.annotated_regions[0].end_index == 20 +def test_azure_ai_chat_client_extract_file_path_contents_with_file_path_annotation( + mock_agents_client: MagicMock, +) -> None: + """Test _extract_file_path_contents with MessageDeltaChunk containing file path annotation.""" + chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") + + # Create mock file_path annotation + mock_file_path = MagicMock() + mock_file_path.file_id = "assistant-test-file-123" + + mock_annotation = MagicMock(spec=MessageDeltaTextFilePathAnnotation) + mock_annotation.file_path = mock_file_path + + # Create mock text content with annotations + mock_text = MagicMock() + mock_text.annotations = [mock_annotation] + + mock_text_content = MagicMock(spec=MessageDeltaTextContent) + mock_text_content.text = mock_text + + # Create mock delta + mock_delta = MagicMock() + mock_delta.content = [mock_text_content] + + # Create mock MessageDeltaChunk + mock_chunk = MagicMock(spec=MessageDeltaChunk) + mock_chunk.delta = mock_delta + + # Call the method + file_contents = chat_client._extract_file_path_contents(mock_chunk) + + # Verify results + assert len(file_contents) == 1 + assert isinstance(file_contents[0], HostedFileContent) + assert file_contents[0].file_id == "assistant-test-file-123" + + +def test_azure_ai_chat_client_extract_file_path_contents_with_file_citation_annotation( + mock_agents_client: MagicMock, +) -> None: + """Test _extract_file_path_contents with MessageDeltaChunk containing file citation annotation.""" + chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") + + # Create mock file_citation annotation + mock_file_citation = MagicMock() + mock_file_citation.file_id = "cfile_test-citation-456" + + mock_annotation = MagicMock(spec=MessageDeltaTextFileCitationAnnotation) + mock_annotation.file_citation = mock_file_citation + + # Create mock text content with annotations + mock_text = MagicMock() + mock_text.annotations = [mock_annotation] + + mock_text_content = MagicMock(spec=MessageDeltaTextContent) + mock_text_content.text = mock_text + + # Create mock delta + mock_delta = MagicMock() + mock_delta.content = [mock_text_content] + + # Create mock MessageDeltaChunk + mock_chunk = MagicMock(spec=MessageDeltaChunk) + mock_chunk.delta = mock_delta + + # Call the method + file_contents = chat_client._extract_file_path_contents(mock_chunk) + + # Verify results + assert len(file_contents) == 1 + assert isinstance(file_contents[0], HostedFileContent) + assert file_contents[0].file_id == "cfile_test-citation-456" + + +def test_azure_ai_chat_client_extract_file_path_contents_empty_annotations( + mock_agents_client: MagicMock, +) -> None: + """Test _extract_file_path_contents with no annotations returns empty list.""" + chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") + + # Create mock text content with no annotations + mock_text = MagicMock() + mock_text.annotations = [] + + mock_text_content = MagicMock(spec=MessageDeltaTextContent) + mock_text_content.text = mock_text + + # Create mock delta + mock_delta = MagicMock() + mock_delta.content = [mock_text_content] + + # Create mock MessageDeltaChunk + mock_chunk = MagicMock(spec=MessageDeltaChunk) + mock_chunk.delta = mock_delta + + # Call the method + file_contents = chat_client._extract_file_path_contents(mock_chunk) + + # Verify results + assert len(file_contents) == 0 + + def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], ) -> str: diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index 2167c1340a..079b93d8c8 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -477,6 +477,30 @@ async def test_azure_ai_client_agent_creation_with_instructions( assert call_args[1]["definition"].instructions == "Message instructions. Option instructions. " +async def test_azure_ai_client_agent_creation_with_additional_args( + mock_project_client: MagicMock, +) -> None: + """Test agent creation with additional arguments.""" + client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent") + + # Mock agent creation response + mock_agent = MagicMock() + mock_agent.name = "test-agent" + mock_agent.version = "1.0" + mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent) + + run_options = {"model": "test-model", "temperature": 0.9, "top_p": 0.8} + messages_instructions = "Message instructions. " + + await client._get_agent_reference_or_create(run_options, messages_instructions) # type: ignore + + # Verify agent was created with provided arguments + call_args = mock_project_client.agents.create_version.call_args + definition = call_args[1]["definition"] + assert definition.temperature == 0.9 + assert definition.top_p == 0.8 + + async def test_azure_ai_client_agent_creation_with_tools( mock_project_client: MagicMock, ) -> None: diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index a1d71f2138..61e678a38f 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -416,7 +416,7 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien request_response_format, ) logger.debug("Signalling entity %s with request: %s", entity_instance_id, run_request) - await client.signal_entity(entity_instance_id, "run_agent", run_request) + await client.signal_entity(entity_instance_id, "run", run_request) logger.debug(f"[HTTP Trigger] Signal sent to entity {session_id}") @@ -497,7 +497,8 @@ def entity_function(context: df.DurableEntityContext) -> None: """Durable entity that manages agent execution and conversation state. Operations: - - run_agent: Execute the agent with a message + - run: Execute the agent with a message + - run_agent: (Deprecated) Execute the agent with a message - reset: Clear conversation history """ entity_handler = create_agent_entity(agent, callback) @@ -639,7 +640,7 @@ async def _handle_mcp_tool_invocation( logger.info("[MCP Tool] Invoking agent '%s' with query: %s", agent_name, query_preview) # Signal entity to run agent - await client.signal_entity(entity_instance_id, "run_agent", run_request) + await client.signal_entity(entity_instance_id, "run", run_request) # Poll for response (similar to HTTP handler) try: diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index a65b69f046..ec06009d88 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -46,7 +46,8 @@ class AgentEntity: - Handles tool execution Operations: - - run_agent: Execute the agent with a message + - run: Execute the agent with a message + - run_agent: (Deprecated) Execute the agent with a message - reset: Clear conversation history Attributes: @@ -94,6 +95,22 @@ async def run_agent( self, context: df.DurableEntityContext, request: RunRequest | dict[str, Any] | str, + ) -> AgentRunResponse: + """(Deprecated) Execute the agent with a message directly in the entity. + + Args: + context: Entity context + request: RunRequest object, dict, or string message (for backward compatibility) + + Returns: + AgentRunResponse enriched with execution metadata. + """ + return await self.run(context, request) + + async def run( + self, + context: df.DurableEntityContext, + request: RunRequest | dict[str, Any] | str, ) -> AgentRunResponse: """Execute the agent with a message directly in the entity. @@ -121,7 +138,7 @@ async def run_agent( response_format = run_request.response_format enable_tool_calls = run_request.enable_tool_calls - logger.debug(f"[AgentEntity.run_agent] Received Message: {run_request}") + logger.debug(f"[AgentEntity.run] Received Message: {run_request}") state_request = DurableAgentStateRequest.from_run_request(run_request) self.state.data.conversation_history.append(state_request) @@ -150,7 +167,7 @@ async def run_agent( ) logger.debug( - "[AgentEntity.run_agent] Agent invocation completed - response type: %s", + "[AgentEntity.run] Agent invocation completed - response type: %s", type(agent_run_response).__name__, ) @@ -167,12 +184,12 @@ async def run_agent( state_response = DurableAgentStateResponse.from_run_response(correlation_id, agent_run_response) self.state.data.conversation_history.append(state_response) - logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history") + logger.debug("[AgentEntity.run] AgentRunResponse stored in conversation history") return agent_run_response except Exception as exc: - logger.exception("[AgentEntity.run_agent] Agent execution failed.") + logger.exception("[AgentEntity.run] Agent execution failed.") # Create error message error_message = ChatMessage( @@ -367,7 +384,7 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: operation = context.operation_name - if operation == "run_agent": + if operation == "run" or operation == "run_agent": input_data: Any = context.get_input() request: str | dict[str, Any] @@ -377,7 +394,7 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: # Fall back to treating input as message string request = "" if input_data is None else str(cast(object, input_data)) - result = await entity.run_agent(context, request) + result = await entity.run(context, request) context.set_result(result.to_dict()) elif operation == "reset": diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index dd67fc7ad4..24b1b27368 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -286,7 +286,7 @@ def my_orchestration(context): logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100]) # Call the entity to get the underlying task - entity_task = self.context.call_entity(entity_id, "run_agent", run_request.to_dict()) + entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict()) # Wrap it in an AgentTask that will convert the result to AgentRunResponse agent_task = AgentTask( diff --git a/python/packages/azurefunctions/pyproject.toml b/python/packages/azurefunctions/pyproject.toml index 3a349f8802..ccc00ffc73 100644 --- a/python/packages/azurefunctions/pyproject.toml +++ b/python/packages/azurefunctions/pyproject.toml @@ -4,7 +4,7 @@ description = "Azure Functions integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/azurefunctions/tests/integration_tests/README.md b/python/packages/azurefunctions/tests/integration_tests/README.md index d9ecb86234..a7f9fadc44 100644 --- a/python/packages/azurefunctions/tests/integration_tests/README.md +++ b/python/packages/azurefunctions/tests/integration_tests/README.md @@ -29,7 +29,7 @@ docker run -d -p 10000:10000 -p 10001:10001 -p 10002:10002 mcr.microsoft.com/azu **Durable Task Scheduler:** ```bash -docker run -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +docker run -d -p 8080:8080 -p 8082:8082 -e DTS_USE_DYNAMIC_TASK_HUBS=true mcr.microsoft.com/dts/dts-emulator:latest ``` ## Running Tests diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index fc2590e225..1fbfa57e39 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -339,7 +339,7 @@ async def test_entity_run_agent_operation(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Test message", "thread_id": "test-conv-123", "correlationId": "corr-app-entity-1"}, ) @@ -359,7 +359,7 @@ async def test_entity_stores_conversation_history(self) -> None: mock_context = Mock() # Send first message - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-app-entity-2"} ) @@ -368,7 +368,7 @@ async def test_entity_stores_conversation_history(self) -> None: assert len(history) == 1 # Just the user message # Send second message - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-2", "correlationId": "corr-app-entity-2b"} ) @@ -399,12 +399,12 @@ async def test_entity_increments_message_count(self) -> None: assert len(entity.state.data.conversation_history) == 0 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-app-entity-3a"} ) assert len(entity.state.data.conversation_history) == 2 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-app-entity-3b"} ) assert len(entity.state.data.conversation_history) == 4 @@ -434,8 +434,36 @@ def test_create_agent_entity_returns_function(self) -> None: assert callable(entity_function) + def test_entity_function_handles_run_operation(self) -> None: + """Test that the entity function handles the run operation.""" + mock_agent = Mock() + mock_agent.run = AsyncMock( + return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")]) + ) + + entity_function = create_agent_entity(mock_agent) + + # Mock context + mock_context = Mock() + mock_context.operation_name = "run" + mock_context.get_input.return_value = { + "message": "Test message", + "thread_id": "conv-123", + "correlationId": "corr-app-factory-1", + } + mock_context.get_state.return_value = None + + # Execute entity function + entity_function(mock_context) + + # Verify result was set + assert mock_context.set_result.called + assert mock_context.set_state.called + result_call = mock_context.set_result.call_args[0][0] + assert "error" not in result_call + def test_entity_function_handles_run_agent_operation(self) -> None: - """Test that the entity function handles the run_agent operation.""" + """Test that the entity function handles the deprecated run_agent operation for backward compatibility.""" mock_agent = Mock() mock_agent.run = AsyncMock( return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")]) @@ -459,6 +487,8 @@ def test_entity_function_handles_run_agent_operation(self) -> None: # Verify result was set assert mock_context.set_result.called assert mock_context.set_state.called + result_call = mock_context.set_result.call_args[0][0] + assert "error" not in result_call def test_entity_function_handles_reset_operation(self) -> None: """Test that the entity function handles the reset operation.""" @@ -586,7 +616,7 @@ async def test_entity_handles_agent_error(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Test message", "thread_id": "conv-1", "correlationId": "corr-app-error-1"} ) @@ -606,7 +636,7 @@ def test_entity_function_handles_exception(self) -> None: entity_function = create_agent_entity(mock_agent) mock_context = Mock() - mock_context.operation_name = "run_agent" + mock_context.operation_name = "run" mock_context.get_input.side_effect = Exception("Input error") mock_context.get_state.return_value = None diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 10b42e2b0e..9d779695df 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -108,6 +108,33 @@ def test_init_with_different_agent_types(self) -> None: class TestAgentEntityRunAgent: """Test suite for the run_agent operation.""" + async def test_run_executes_agent(self) -> None: + """Test that run executes the agent.""" + mock_agent = Mock() + mock_response = _agent_response("Test response") + mock_agent.run = AsyncMock(return_value=mock_response) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run( + mock_context, {"message": "Test message", "thread_id": "conv-123", "correlationId": "corr-entity-1"} + ) + + # Verify agent.run was called + mock_agent.run.assert_called_once() + _, kwargs = mock_agent.run.call_args + sent_messages: list[Any] = kwargs.get("messages") + assert len(sent_messages) == 1 + sent_message = sent_messages[0] + assert isinstance(sent_message, ChatMessage) + assert getattr(sent_message, "text", None) == "Test message" + assert getattr(sent_message.role, "value", sent_message.role) == "user" + + # Verify result + assert isinstance(result, AgentRunResponse) + assert result.text == "Test response" + async def test_run_agent_executes_agent(self) -> None: """Test that run_agent executes the agent.""" mock_agent = Mock() @@ -156,7 +183,7 @@ async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]: entity = AgentEntity(mock_agent, callback=callback) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, { "message": "Tell me something", @@ -203,7 +230,7 @@ async def test_run_agent_final_callback_without_streaming(self) -> None: entity = AgentEntity(mock_agent, callback=callback) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, { "message": "Hi", @@ -235,7 +262,7 @@ async def test_run_agent_updates_conversation_history(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - await entity.run_agent( + await entity.run( mock_context, {"message": "User message", "thread_id": "conv-1", "correlationId": "corr-entity-2"} ) @@ -263,17 +290,17 @@ async def test_run_agent_increments_message_count(self) -> None: assert len(entity.state.data.conversation_history) == 0 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-3a"} ) assert len(entity.state.data.conversation_history) == 2 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-3b"} ) assert len(entity.state.data.conversation_history) == 4 - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-3c"} ) assert len(entity.state.data.conversation_history) == 6 @@ -287,9 +314,7 @@ async def test_run_agent_with_none_thread_id(self) -> None: mock_context = Mock() with pytest.raises(ValueError, match="thread_id"): - await entity.run_agent( - mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"} - ) + await entity.run(mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"}) async def test_run_agent_multiple_conversations(self) -> None: """Test that run_agent maintains history across multiple messages.""" @@ -300,13 +325,13 @@ async def test_run_agent_multiple_conversations(self) -> None: mock_context = Mock() # Send multiple messages - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-8a"} ) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-8b"} ) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-8c"} ) @@ -374,10 +399,10 @@ async def test_reset_after_conversation(self) -> None: mock_context = Mock() # Have a conversation - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-10a"} ) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-10b"} ) @@ -413,7 +438,7 @@ def test_entity_function_handles_run_agent(self) -> None: # Mock context mock_context = Mock() - mock_context.operation_name = "run_agent" + mock_context.operation_name = "run" mock_context.get_input.return_value = { "message": "Test message", "thread_id": "conv-123", @@ -576,7 +601,7 @@ async def test_run_agent_handles_agent_exception(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-1"} ) @@ -595,7 +620,7 @@ async def test_run_agent_handles_value_error(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-2"} ) @@ -614,7 +639,7 @@ async def test_run_agent_handles_timeout_error(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-3"} ) @@ -631,7 +656,7 @@ def test_entity_function_handles_exception_in_operation(self) -> None: entity_function = create_agent_entity(mock_agent) mock_context = Mock() - mock_context.operation_name = "run_agent" + mock_context.operation_name = "run" mock_context.get_input.side_effect = Exception("Input error") mock_context.get_state.return_value = None @@ -651,7 +676,7 @@ async def test_run_agent_preserves_message_on_error(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - result = await entity.run_agent( + result = await entity.run( mock_context, {"message": "Test message", "thread_id": "conv-123", "correlationId": "corr-entity-error-4"}, ) @@ -674,7 +699,7 @@ async def test_conversation_history_has_timestamps(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - await entity.run_agent( + await entity.run( mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-history-1"} ) @@ -694,19 +719,19 @@ async def test_conversation_history_ordering(self) -> None: # Send multiple messages with different responses mock_agent.run = AsyncMock(return_value=_agent_response("Response 1")) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-history-2a"}, ) mock_agent.run = AsyncMock(return_value=_agent_response("Response 2")) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-history-2b"}, ) mock_agent.run = AsyncMock(return_value=_agent_response("Response 3")) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-history-2c"}, ) @@ -729,11 +754,11 @@ async def test_conversation_history_role_alternation(self) -> None: entity = AgentEntity(mock_agent) mock_context = Mock() - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-history-3a"}, ) - await entity.run_agent( + await entity.run( mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-history-3b"}, ) @@ -766,7 +791,7 @@ async def test_run_agent_with_run_request_object(self) -> None: correlation_id="corr-runreq-1", ) - result = await entity.run_agent(mock_context, request) + result = await entity.run(mock_context, request) assert isinstance(result, AgentRunResponse) assert result.text == "Response" @@ -787,7 +812,7 @@ async def test_run_agent_with_dict_request(self) -> None: "correlationId": "corr-runreq-2", } - result = await entity.run_agent(mock_context, request_dict) + result = await entity.run(mock_context, request_dict) assert isinstance(result, AgentRunResponse) assert result.text == "Response" @@ -801,7 +826,7 @@ async def test_run_agent_with_string_raises_without_correlation(self) -> None: mock_context = Mock() with pytest.raises(ValueError): - await entity.run_agent(mock_context, "Simple message") + await entity.run(mock_context, "Simple message") async def test_run_agent_stores_role_in_history(self) -> None: """Test that run_agent stores the role in conversation history.""" @@ -819,7 +844,7 @@ async def test_run_agent_stores_role_in_history(self) -> None: correlation_id="corr-runreq-3", ) - await entity.run_agent(mock_context, request) + await entity.run(mock_context, request) # Check that system role was stored history = entity.state.data.conversation_history @@ -842,7 +867,7 @@ async def test_run_agent_with_response_format(self) -> None: correlation_id="corr-runreq-4", ) - result = await entity.run_agent(mock_context, request) + result = await entity.run(mock_context, request) assert isinstance(result, AgentRunResponse) assert result.text == '{"answer": 42}' @@ -860,7 +885,7 @@ async def test_run_agent_disable_tool_calls(self) -> None: message="Test", thread_id="conv-runreq-5", enable_tool_calls=False, correlation_id="corr-runreq-5" ) - result = await entity.run_agent(mock_context, request) + result = await entity.run(mock_context, request) assert isinstance(result, AgentRunResponse) # Agent should have been called (tool disabling is framework-dependent) @@ -874,7 +899,7 @@ async def test_entity_function_with_run_request_dict(self) -> None: entity_function = create_agent_entity(mock_agent) mock_context = Mock() - mock_context.operation_name = "run_agent" + mock_context.operation_name = "run" mock_context.get_input.return_value = { "message": "Test message", "thread_id": "conv-789", diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index 0f845d4105..b0dd313b0b 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -295,7 +295,7 @@ def test_run_creates_entity_call(self) -> None: call_args = mock_context.call_entity.call_args entity_id, operation, request = call_args[0] - assert operation == "run_agent" + assert operation == "run" assert request["message"] == "Test message" assert request["enable_tool_calls"] is True assert "correlationId" in request diff --git a/python/packages/chatkit/pyproject.toml b/python/packages/chatkit/pyproject.toml index 9bb72d7500..8c73a2ffb0 100644 --- a/python/packages/chatkit/pyproject.toml +++ b/python/packages/chatkit/pyproject.toml @@ -4,7 +4,7 @@ description = "OpenAI ChatKit integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/copilotstudio/pyproject.toml b/python/packages/copilotstudio/pyproject.toml index 4301111482..c8517c6eea 100644 --- a/python/packages/copilotstudio/pyproject.toml +++ b/python/packages/copilotstudio/pyproject.toml @@ -4,7 +4,7 @@ description = "Copilot Studio integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 249ee9ecfb..aadd1be40a 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -34,7 +34,7 @@ ToolMode, ) from .exceptions import AgentExecutionException, AgentInitializationError -from .observability import use_agent_observability +from .observability import use_agent_instrumentation if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover @@ -516,8 +516,8 @@ def _prepare_context_providers( @use_agent_middleware -@use_agent_observability -class ChatAgent(BaseAgent): +@use_agent_instrumentation(capture_usage=False) # type: ignore[arg-type,misc] +class ChatAgent(BaseAgent): # type: ignore[misc] """A Chat Client Agent. This is the primary agent implementation that uses a chat client to interact @@ -583,7 +583,7 @@ def get_weather(location: str) -> str: print(update.text, end="") """ - AGENT_SYSTEM_NAME: ClassVar[str] = "microsoft.agent_framework" + AGENT_PROVIDER_NAME: ClassVar[str] = "microsoft.agent_framework" def __init__( self, @@ -878,6 +878,9 @@ async def run( user=user, additional_properties=merged_additional_options, # type: ignore[arg-type] ) + + # Ensure thread is forwarded in kwargs for tool invocation + kwargs["thread"] = thread # Filter chat_options from kwargs to prevent duplicate keyword argument filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} response = await self.chat_client.get_response( @@ -895,7 +898,12 @@ async def run( # Only notify the thread of new messages if the chatResponse was successful # to avoid inconsistent messages state in the thread. - await self._notify_thread_of_new_messages(thread, input_messages, response.messages) + await self._notify_thread_of_new_messages( + thread, + input_messages, + response.messages, + **{k: v for k, v in kwargs.items() if k != "thread"}, + ) return AgentRunResponse( messages=response.messages, response_id=response.response_id, @@ -1017,6 +1025,8 @@ async def run_stream( additional_properties=merged_additional_options, # type: ignore[arg-type] ) + # Ensure thread is forwarded in kwargs for tool invocation + kwargs["thread"] = thread # Filter chat_options from kwargs to prevent duplicate keyword argument filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} response_updates: list[ChatResponseUpdate] = [] @@ -1043,7 +1053,13 @@ async def run_stream( response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format) await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) - await self._notify_thread_of_new_messages(thread, input_messages, response.messages, **kwargs) + + await self._notify_thread_of_new_messages( + thread, + input_messages, + response.messages, + **{k: v for k, v in kwargs.items() if k != "thread"}, + ) @override def get_new_thread( diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 4d91492822..506a1be7cd 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -8,7 +8,6 @@ from pydantic import BaseModel from ._logging import get_logger -from ._mcp import MCPTool from ._memory import AggregateContextProvider, ContextProvider from ._middleware import ( ChatMiddleware, @@ -426,6 +425,8 @@ async def _normalize_tools( else [tools] ) for tool in tools_list: # type: ignore[reportUnknownType] + from ._mcp import MCPTool + if isinstance(tool, MCPTool): if not tool.is_connected: await tool.connect() diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index b4caaea4f5..37e0d2c54b 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import json import logging import re import sys @@ -19,9 +18,9 @@ from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.session import RequestResponder -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, create_model -from ._tools import AIFunction, HostedMCPSpecificApproval +from ._tools import AIFunction, HostedMCPSpecificApproval, _build_pydantic_model_from_json_schema from ._types import ( ChatMessage, Contents, @@ -153,7 +152,7 @@ def _mcp_type_to_ai_content( case types.ImageContent() | types.AudioContent(): return_types.append( DataContent( - uri=mcp_type.data, + data=mcp_type.data, media_type=mcp_type.mimeType, raw_representation=mcp_type, ) @@ -274,95 +273,26 @@ def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]: if not prompt.arguments: return create_model(f"{prompt.name}_input") - field_definitions: dict[str, Any] = {} - for prompt_argument in prompt.arguments: - # For prompts, all arguments are typically required and string type - # unless specified otherwise in the prompt argument - python_type = str # Default type for prompt arguments + # Convert prompt arguments to JSON schema format + properties: dict[str, Any] = {} + required: list[str] = [] - # Create field definition for create_model + for prompt_argument in prompt.arguments: + # For prompts, all arguments are typically string type unless specified otherwise + properties[prompt_argument.name] = { + "type": "string", + "description": prompt_argument.description if hasattr(prompt_argument, "description") else "", + } if prompt_argument.required: - field_definitions[prompt_argument.name] = (python_type, ...) - else: - field_definitions[prompt_argument.name] = (python_type, None) + required.append(prompt_argument.name) - return create_model(f"{prompt.name}_input", **field_definitions) + schema = {"properties": properties, "required": required} + return _build_pydantic_model_from_json_schema(prompt.name, schema) def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]: """Creates a Pydantic model from a tools parameters.""" - properties = tool.inputSchema.get("properties", None) - required = tool.inputSchema.get("required", []) - definitions = tool.inputSchema.get("$defs", {}) - - # Check if 'properties' is missing or not a dictionary - if not properties: - return create_model(f"{tool.name}_input") - - def resolve_type(prop_details: dict[str, Any]) -> type: - """Resolve JSON Schema type to Python type, handling $ref.""" - # Handle $ref by resolving the reference - if "$ref" in prop_details: - ref = prop_details["$ref"] - # Extract the reference path (e.g., "#/$defs/CustomerIdParam" -> "CustomerIdParam") - if ref.startswith("#/$defs/"): - def_name = ref.split("/")[-1] - if def_name in definitions: - # Resolve the reference and use its type - resolved = definitions[def_name] - return resolve_type(resolved) - # If we can't resolve the ref, default to dict for safety - return dict - - # Map JSON Schema types to Python types - json_type = prop_details.get("type", "string") - match json_type: - case "integer": - return int - case "number": - return float - case "boolean": - return bool - case "array": - return list - case "object": - return dict - case _: - return str # default - - field_definitions: dict[str, Any] = {} - for prop_name, prop_details in properties.items(): - prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details - - python_type = resolve_type(prop_details) - description = prop_details.get("description", "") - - # Build field kwargs (description, array items schema, etc.) - field_kwargs: dict[str, Any] = {} - if description: - field_kwargs["description"] = description - - # Preserve array items schema if present - if prop_details.get("type") == "array" and "items" in prop_details: - items_schema = prop_details["items"] - if items_schema and items_schema != {}: - field_kwargs["json_schema_extra"] = {"items": items_schema} - - # Create field definition for create_model - if prop_name in required: - if field_kwargs: - field_definitions[prop_name] = (python_type, Field(**field_kwargs)) - else: - field_definitions[prop_name] = (python_type, ...) - else: - default_value = prop_details.get("default", None) - field_kwargs["default"] = default_value - if field_kwargs and any(k != "default" for k in field_kwargs): - field_definitions[prop_name] = (python_type, Field(**field_kwargs)) - else: - field_definitions[prop_name] = (python_type, default_value) - - return create_model(f"{tool.name}_input", **field_definitions) + return _build_pydantic_model_from_json_schema(tool.name, tool.inputSchema) def _normalize_mcp_name(name: str) -> str: @@ -755,8 +685,16 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]: raise ToolExecutionException( "Tools are not loaded for this server, please set load_tools=True in the constructor." ) + # Filter out framework kwargs that cannot be serialized by the MCP SDK. + # These are internal objects passed through the function invocation pipeline + # that should not be forwarded to external MCP servers. + filtered_kwargs = { + k: v for k, v in kwargs.items() if k not in {"chat_options", "tools", "tool_choice", "thread"} + } try: - return _mcp_call_tool_result_to_ai_contents(await self.session.call_tool(tool_name, arguments=kwargs)) + return _mcp_call_tool_result_to_ai_contents( + await self.session.call_tool(tool_name, arguments=filtered_kwargs) + ) except McpError as mcp_exc: raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc except Exception as ex: diff --git a/python/packages/core/agent_framework/_memory.py b/python/packages/core/agent_framework/_memory.py index 4b2a01ad24..a5b53fc39f 100644 --- a/python/packages/core/agent_framework/_memory.py +++ b/python/packages/core/agent_framework/_memory.py @@ -6,11 +6,13 @@ from collections.abc import MutableSequence, Sequence from contextlib import AsyncExitStack from types import TracebackType -from typing import Any, Final, cast +from typing import TYPE_CHECKING, Any, Final, cast -from ._tools import ToolProtocol from ._types import ChatMessage +if TYPE_CHECKING: + from ._tools import ToolProtocol + if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: @@ -54,7 +56,7 @@ def __init__( self, instructions: str | None = None, messages: Sequence[ChatMessage] | None = None, - tools: Sequence[ToolProtocol] | None = None, + tools: Sequence["ToolProtocol"] | None = None, ): """Create a new Context object. @@ -65,7 +67,7 @@ def __init__( """ self.instructions = instructions self.messages: Sequence[ChatMessage] = messages or [] - self.tools: Sequence[ToolProtocol] = tools or [] + self.tools: Sequence["ToolProtocol"] = tools or [] # region ContextProvider @@ -247,7 +249,7 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * contexts = await asyncio.gather(*[provider.invoking(messages, **kwargs) for provider in self.providers]) instructions: str = "" return_messages: list[ChatMessage] = [] - tools: list[ToolProtocol] = [] + tools: list["ToolProtocol"] = [] for ctx in contexts: if ctx.instructions: instructions += ctx.instructions diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 9bb730ba62..4e36cb764a 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -1405,13 +1405,17 @@ async def _stream_generator() -> Any: call_middleware = kwargs.pop("middleware", None) instance_middleware = getattr(self, "middleware", None) - # Merge middleware from both sources, filtering for chat middleware only - all_middleware: list[ChatMiddleware | ChatMiddlewareCallable] = _merge_and_filter_chat_middleware( - instance_middleware, call_middleware - ) + # Merge all middleware and separate by type + middleware = categorize_middleware(instance_middleware, call_middleware) + chat_middleware_list = middleware["chat"] + function_middleware_list = middleware["function"] + + # Pass function middleware to function invocation system if present + if function_middleware_list: + kwargs["_function_middleware_pipeline"] = FunctionMiddlewarePipeline(function_middleware_list) - # If no middleware, use original method - if not all_middleware: + # If no chat middleware, use original method + if not chat_middleware_list: async for update in original_get_streaming_response(self, messages, **kwargs): yield update return @@ -1422,7 +1426,7 @@ async def _stream_generator() -> Any: # Extract chat_options or create default chat_options = kwargs.pop("chat_options", ChatOptions()) - pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type] + pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type] context = ChatContext( chat_client=self, messages=prepare_messages(messages), @@ -1536,27 +1540,40 @@ def _merge_and_filter_chat_middleware( return middleware["chat"] # type: ignore[return-value] -def extract_and_merge_function_middleware(chat_client: Any, **kwargs: Any) -> None: +def extract_and_merge_function_middleware( + chat_client: Any, kwargs: dict[str, Any] +) -> "FunctionMiddlewarePipeline | None": """Extract function middleware from chat client and merge with existing pipeline in kwargs. Args: chat_client: The chat client instance to extract middleware from. + kwargs: Dictionary containing middleware and pipeline information. - Keyword Args: - **kwargs: Dictionary containing middleware and pipeline information. + Returns: + A FunctionMiddlewarePipeline if function middleware is found, None otherwise. """ + # Check if a pipeline was already created by use_chat_middleware + existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline") + # Get middleware sources client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None run_level_middleware = kwargs.get("middleware") - existing_pipeline = kwargs.get("_function_middleware_pipeline") - # Extract existing pipeline middlewares if present - existing_middlewares = existing_pipeline._middlewares if existing_pipeline else None + # If we have an existing pipeline but no additional middleware sources, return it directly + if existing_pipeline and not client_middleware and not run_level_middleware: + return existing_pipeline + + # If we have an existing pipeline with additional middleware, we need to merge + # Extract existing pipeline middlewares if present - cast to list[Middleware] for type compatibility + existing_middlewares: list[Middleware] | None = list(existing_pipeline._middlewares) if existing_pipeline else None # Create combined pipeline from all sources using existing helper combined_pipeline = create_function_middleware_pipeline( client_middleware, run_level_middleware, existing_middlewares ) - if combined_pipeline: - kwargs["_function_middleware_pipeline"] = combined_pipeline + # If we have an existing pipeline but combined is None (no new middlewares), return existing + if existing_pipeline and combined_pipeline is None: + return existing_pipeline + + return combined_pipeline diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index 1a38d9030a..cf28df2f4f 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -339,11 +339,17 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) continue # Handle dicts containing SerializationProtocol values if isinstance(value, dict): + from datetime import date, datetime, time + serialized_dict: dict[str, Any] = {} for k, v in value.items(): if isinstance(v, SerializationProtocol): serialized_dict[k] = v.to_dict(exclude=exclude, exclude_none=exclude_none) continue + # Convert datetime objects to strings + if isinstance(v, (datetime, date, time)): + serialized_dict[k] = str(v) + continue # Check if the value is JSON serializable if is_serializable(v): serialized_dict[k] = v diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 3657a994e2..2f7801c84b 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -25,7 +25,6 @@ from opentelemetry.metrics import Histogram from pydantic import AnyUrl, BaseModel, Field, ValidationError, create_model -from pydantic.fields import FieldInfo from ._logging import get_logger from ._serialization import SerializationMixin @@ -628,6 +627,12 @@ def __init__( self._invocation_duration_histogram = _default_histogram() self.type: Literal["ai_function"] = "ai_function" self._forward_runtime_kwargs: bool = False + if self.func: + sig = inspect.signature(self.func) + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + self._forward_runtime_kwargs = True + break @property def declaration_only(self) -> bool: @@ -881,6 +886,8 @@ def _parse_annotation(annotation: Any) -> Any: If the second annotation (after the type) is a string, then we convert that to a Pydantic Field description. The rest are returned as-is, allowing for multiple annotations. + Literal types are returned as-is to preserve their enum-like values. + Args: annotation: The type annotation to parse. @@ -889,6 +896,12 @@ def _parse_annotation(annotation: Any) -> Any: """ origin = get_origin(annotation) if origin is not None: + # Literal types should be returned as-is - their args are the allowed values, + # not type annotations to be parsed. For example, Literal["Data", "Security"] + # has args ("Data", "Security") which are the valid string values. + if origin is Literal: + return annotation + args = get_args(annotation) # For other generics, return the origin type (e.g., list for List[int]) if len(args) > 1 and isinstance(args[1], str): @@ -916,6 +929,7 @@ def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[B ) for pname, param in sig.parameters.items() if pname not in {"self", "cls"} + and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} } return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return] @@ -932,6 +946,151 @@ def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[B } +def _build_pydantic_model_from_json_schema( + model_name: str, + schema: Mapping[str, Any], +) -> type[BaseModel]: + """Creates a Pydantic model from JSON Schema with support for $refs, nested objects, and typed arrays. + + Args: + model_name: The name of the model to be created. + schema: The JSON Schema definition (should contain 'properties', 'required', '$defs', etc.). + + Returns: + The dynamically created Pydantic model class. + """ + properties = schema.get("properties") + required = schema.get("required", []) + definitions = schema.get("$defs", {}) + + # Check if 'properties' is missing or not a dictionary + if not properties: + return create_model(f"{model_name}_input") + + def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: + """Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays. + + Args: + prop_details: The JSON Schema property details + parent_name: Name to use for creating nested models (for uniqueness) + + Returns: + Python type annotation (could be int, str, list[str], or a nested Pydantic model) + """ + # Handle $ref by resolving the reference + if "$ref" in prop_details: + ref = prop_details["$ref"] + # Extract the reference path (e.g., "#/$defs/CustomerIdParam" -> "CustomerIdParam") + if ref.startswith("#/$defs/"): + def_name = ref.split("/")[-1] + if def_name in definitions: + # Resolve the reference and use its type + resolved = definitions[def_name] + return _resolve_type(resolved, def_name) + # If we can't resolve the ref, default to dict for safety + return dict + + # Map JSON Schema types to Python types + json_type = prop_details.get("type", "string") + match json_type: + case "integer": + return int + case "number": + return float + case "boolean": + return bool + case "array": + # Handle typed arrays + items_schema = prop_details.get("items") + if items_schema and isinstance(items_schema, dict): + # Recursively resolve the item type + item_type = _resolve_type(items_schema, f"{parent_name}_item") + # Return list[ItemType] instead of bare list + return list[item_type] # type: ignore + # If no items schema or invalid, return bare list + return list + case "object": + # Handle nested objects by creating a nested Pydantic model + nested_properties = prop_details.get("properties") + nested_required = prop_details.get("required", []) + + if nested_properties and isinstance(nested_properties, dict): + # Create the name for the nested model + nested_model_name = f"{parent_name}_nested" if parent_name else "NestedModel" + + # Recursively build field definitions for the nested model + nested_field_definitions: dict[str, Any] = {} + for nested_prop_name, nested_prop_details in nested_properties.items(): + nested_prop_details = ( + json.loads(nested_prop_details) + if isinstance(nested_prop_details, str) + else nested_prop_details + ) + + nested_python_type = _resolve_type( + nested_prop_details, f"{nested_model_name}_{nested_prop_name}" + ) + nested_description = nested_prop_details.get("description", "") + + # Build field kwargs for nested property + nested_field_kwargs: dict[str, Any] = {} + if nested_description: + nested_field_kwargs["description"] = nested_description + + # Create field definition + if nested_prop_name in nested_required: + nested_field_definitions[nested_prop_name] = ( + ( + nested_python_type, + Field(**nested_field_kwargs), + ) + if nested_field_kwargs + else (nested_python_type, ...) + ) + else: + nested_field_kwargs["default"] = nested_prop_details.get("default", None) + nested_field_definitions[nested_prop_name] = ( + nested_python_type, + Field(**nested_field_kwargs), + ) + + # Create and return the nested Pydantic model + return create_model(nested_model_name, **nested_field_definitions) # type: ignore + + # If no properties defined, return bare dict + return dict + case _: + return str # default + + field_definitions: dict[str, Any] = {} + for prop_name, prop_details in properties.items(): + prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details + + python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}") + description = prop_details.get("description", "") + + # Build field kwargs (description, etc.) + field_kwargs: dict[str, Any] = {} + if description: + field_kwargs["description"] = description + + # Create field definition for create_model + if prop_name in required: + if field_kwargs: + field_definitions[prop_name] = (python_type, Field(**field_kwargs)) + else: + field_definitions[prop_name] = (python_type, ...) + else: + default_value = prop_details.get("default", None) + field_kwargs["default"] = default_value + if field_kwargs and any(k != "default" for k in field_kwargs): + field_definitions[prop_name] = (python_type, Field(**field_kwargs)) + else: + field_definitions[prop_name] = (python_type, default_value) + + return create_model(f"{model_name}_input", **field_definitions) + + def _create_model_from_json_schema(tool_name: str, schema_json: Mapping[str, Any]) -> type[BaseModel]: """Creates a Pydantic model from a given JSON Schema. @@ -948,29 +1107,8 @@ def _create_model_from_json_schema(tool_name: str, schema_json: Mapping[str, Any f"JSON schema for tool '{tool_name}' must contain a 'properties' key of type dict. " f"Got: {schema_json.get('properties', None)}" ) - # Extract field definitions with type annotations - field_definitions: dict[str, tuple[type, FieldInfo]] = {} - for field_name, field_schema in schema_json["properties"].items(): - field_args: dict[str, Any] = {} - if (field_description := field_schema.get("description", None)) is not None: - field_args["description"] = field_description - if (field_default := field_schema.get("default", None)) is not None: - field_args["default"] = field_default - field_type = field_schema.get("type", None) - if field_type is None: - raise ValueError( - f"Missing 'type' for field '{field_name}' in JSON schema. " - f"Got: {field_schema}, Supported types: {list(TYPE_MAPPING.keys())}" - ) - python_type = TYPE_MAPPING.get(field_type) - if python_type is None: - raise ValueError( - f"Unsupported type '{field_type}' for field '{field_name}' in JSON schema. " - f"Got: {field_schema}, Supported types: {list(TYPE_MAPPING.keys())}" - ) - field_definitions[field_name] = (python_type, Field(**field_args)) - return create_model(f"{tool_name}_input", **field_definitions) # type: ignore[call-overload, no-any-return] + return _build_pydantic_model_from_json_schema(tool_name, schema_json) @overload @@ -1218,6 +1356,35 @@ def __init__( self.include_detailed_errors = include_detailed_errors +class FunctionExecutionResult: + """Internal wrapper pairing function output with loop control signals. + + Function execution produces two distinct concerns: the semantic result (returned to + the LLM as FunctionResultContent) and control flow decisions (whether middleware + requested early termination). This wrapper keeps control signals out of user-facing + content types while allowing _try_execute_function_calls to communicate both. + + Not exposed to users. + + Attributes: + content: The FunctionResultContent or other content from the function execution. + terminate: If True, the function invocation loop should exit immediately without + another LLM call. Set when middleware sets context.terminate=True. + """ + + __slots__ = ("content", "terminate") + + def __init__(self, content: "Contents", terminate: bool = False) -> None: + """Initialize FunctionExecutionResult. + + Args: + content: The content from the function execution. + terminate: Whether to terminate the function calling loop. + """ + self.content = content + self.terminate = terminate + + async def _auto_invoke_function( function_call_content: "FunctionCallContent | FunctionApprovalResponseContent", custom_args: dict[str, Any] | None = None, @@ -1227,7 +1394,7 @@ async def _auto_invoke_function( sequence_index: int | None = None, request_index: int | None = None, middleware_pipeline: Any = None, # Optional MiddlewarePipeline -) -> "Contents": +) -> "FunctionExecutionResult | Contents": """Invoke a function call requested by the agent, applying middleware that is defined. Args: @@ -1242,7 +1409,8 @@ async def _auto_invoke_function( middleware_pipeline: Optional middleware pipeline to apply during execution. Returns: - A FunctionResultContent containing the result or exception. + A FunctionExecutionResult wrapping the content and terminate signal, + or a Contents object for approval/hosted tool scenarios. Raises: KeyError: If the requested function is not found in the tool map. @@ -1262,10 +1430,12 @@ async def _auto_invoke_function( # Tool should exist because _try_execute_function_calls validates this if tool is None: exc = KeyError(f'Function "{function_call_content.name}" not found.') - return FunctionResultContent( - call_id=function_call_content.call_id, - result=f'Error: Requested function "{function_call_content.name}" not found.', - exception=exc, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=f'Error: Requested function "{function_call_content.name}" not found.', + exception=exc, + ) ) else: # Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results @@ -1290,7 +1460,9 @@ async def _auto_invoke_function( message = "Error: Argument parsing failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) if not middleware_pipeline or ( not hasattr(middleware_pipeline, "has_middlewares") and not middleware_pipeline.has_middlewares @@ -1302,15 +1474,19 @@ async def _auto_invoke_function( tool_call_id=function_call_content.call_id, **runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) - return FunctionResultContent( - call_id=function_call_content.call_id, - result=function_result, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=function_result, + ) ) except Exception as exc: message = "Error: Function failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) # Execute through middleware pipeline if available from ._middleware import FunctionInvocationContext @@ -1334,15 +1510,20 @@ async def final_function_handler(context_obj: Any) -> Any: context=middleware_context, final_handler=final_function_handler, ) - return FunctionResultContent( - call_id=function_call_content.call_id, - result=function_result, + return FunctionExecutionResult( + content=FunctionResultContent( + call_id=function_call_content.call_id, + result=function_result, + ), + terminate=middleware_context.terminate, ) except Exception as exc: message = "Error: Function failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" - return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + return FunctionExecutionResult( + content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + ) def _get_tool_map( @@ -1373,7 +1554,7 @@ async def _try_execute_function_calls( | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]", config: FunctionInvocationConfiguration, middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports -) -> Sequence["Contents"]: +) -> tuple[Sequence["Contents"], bool]: """Execute multiple function calls concurrently. Args: @@ -1385,9 +1566,11 @@ async def _try_execute_function_calls( middleware_pipeline: Optional middleware pipeline to apply during execution. Returns: - A list of Contents containing the results of each function call, - or the approval requests if any function requires approval, - or the original function calls if any are declaration only. + A tuple of: + - A list of Contents containing the results of each function call, + or the approval requests if any function requires approval, + or the original function calls if any are declaration only. + - A boolean indicating whether to terminate the function calling loop. """ from ._types import FunctionApprovalRequestContent, FunctionCallContent @@ -1410,17 +1593,20 @@ async def _try_execute_function_calls( raise KeyError(f'Error: Requested function "{fcc.name}" not found.') if approval_needed: # approval can only be needed for Function Call Contents, not Approval Responses. - return [ - FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc) - for fcc in function_calls - if isinstance(fcc, FunctionCallContent) - ] + return ( + [ + FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc) + for fcc in function_calls + if isinstance(fcc, FunctionCallContent) + ], + False, + ) if declaration_only_flag: # return the declaration only tools to the user, since we cannot execute them. - return [fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)] + return ([fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)], False) # Run all function calls concurrently - return await asyncio.gather(*[ + execution_results = await asyncio.gather(*[ _auto_invoke_function( function_call_content=function_call, # type: ignore[arg-type] custom_args=custom_args, @@ -1433,6 +1619,20 @@ async def _try_execute_function_calls( for seq_idx, function_call in enumerate(function_calls) ]) + # Unpack FunctionExecutionResult wrappers and check for terminate signal + contents: list[Contents] = [] + should_terminate = False + for result in execution_results: + if isinstance(result, FunctionExecutionResult): + contents.append(result.content) + if result.terminate: + should_terminate = True + else: + # Direct Contents (e.g., from hosted tools) + contents.append(result) + + return (contents, should_terminate) + def _update_conversation_id(kwargs: dict[str, Any], conversation_id: str | None) -> None: """Update kwargs with conversation id. @@ -1565,12 +1765,8 @@ async def function_invocation_wrapper( prepare_messages, ) - # Extract and merge function middleware from chat client with kwargs pipeline - extract_and_merge_function_middleware(self, **kwargs) - - # Extract the middleware pipeline before calling the underlying function - # because the underlying function may not preserve it in kwargs - stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline") + # Extract and merge function middleware from chat client with kwargs + stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs) # Get the config for function invocation (not part of ChatClientProtocol, hence getattr) config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None) @@ -1596,7 +1792,7 @@ async def function_invocation_wrapper( approved_responses = [resp for resp in fcc_todo.values() if resp.approved] approved_function_results: list[Contents] = [] if approved_responses: - approved_function_results = await _try_execute_function_calls( + results, _ = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=approved_responses, @@ -1604,6 +1800,7 @@ async def function_invocation_wrapper( middleware_pipeline=stored_middleware_pipeline, config=config, ) + approved_function_results = list(results) if any( fcr.exception is not None for fcr in approved_function_results @@ -1621,7 +1818,9 @@ async def function_invocation_wrapper( break _replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results) - response = await func(self, messages=prepped_messages, **kwargs) + # Filter out internal framework kwargs before passing to clients. + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + response = await func(self, messages=prepped_messages, **filtered_kwargs) # if there are function calls, we will handle them first function_results = { it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent) @@ -1641,7 +1840,7 @@ async def function_invocation_wrapper( if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function - function_call_results: list[Contents] = await _try_execute_function_calls( + function_call_results, should_terminate = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=function_calls, @@ -1666,6 +1865,17 @@ async def function_invocation_wrapper( # the function calls are already in the response, so we just continue return response + # Check if middleware signaled to terminate the loop (context.terminate=True) + # This allows middleware to short-circuit the tool loop without another LLM call + if should_terminate: + # Add tool results to response and return immediately without calling LLM again + result_message = ChatMessage(role="tool", contents=function_call_results) + response.messages.append(result_message) + if fcc_messages: + for msg in reversed(fcc_messages): + response.messages.insert(0, msg) + return response + if any( fcr.exception is not None for fcr in function_call_results @@ -1710,7 +1920,10 @@ async def function_invocation_wrapper( # Failsafe: give up on tools, ask model for plain answer kwargs["tool_choice"] = "none" - response = await func(self, messages=prepped_messages, **kwargs) + + # Filter out internal framework kwargs before passing to clients. + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + response = await func(self, messages=prepped_messages, **filtered_kwargs) if fcc_messages: for msg in reversed(fcc_messages): response.messages.insert(0, msg) @@ -1755,12 +1968,8 @@ async def streaming_function_invocation_wrapper( prepare_messages, ) - # Extract and merge function middleware from chat client with kwargs pipeline - extract_and_merge_function_middleware(self, **kwargs) - - # Extract the middleware pipeline before calling the underlying function - # because the underlying function may not preserve it in kwargs - stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline") + # Extract and merge function middleware from chat client with kwargs + stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs) # Get the config for function invocation (not part of ChatClientProtocol, hence getattr) config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None) @@ -1779,7 +1988,7 @@ async def streaming_function_invocation_wrapper( approved_responses = [resp for resp in fcc_todo.values() if resp.approved] approved_function_results: list[Contents] = [] if approved_responses: - approved_function_results = await _try_execute_function_calls( + results, _ = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=approved_responses, @@ -1787,6 +1996,7 @@ async def streaming_function_invocation_wrapper( middleware_pipeline=stored_middleware_pipeline, config=config, ) + approved_function_results = list(results) if any( fcr.exception is not None for fcr in approved_function_results @@ -1797,7 +2007,9 @@ async def streaming_function_invocation_wrapper( _replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results) all_updates: list["ChatResponseUpdate"] = [] - async for update in func(self, messages=prepped_messages, **kwargs): + # Filter out internal framework kwargs before passing to clients. + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + async for update in func(self, messages=prepped_messages, **filtered_kwargs): all_updates.append(update) yield update @@ -1839,7 +2051,7 @@ async def streaming_function_invocation_wrapper( if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function - function_call_results: list[Contents] = await _try_execute_function_calls( + function_call_results, should_terminate = await _try_execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=function_calls, @@ -1868,6 +2080,13 @@ async def streaming_function_invocation_wrapper( # the function calls were already yielded. return + # Check if middleware signaled to terminate the loop (context.terminate=True) + # This allows middleware to short-circuit the tool loop without another LLM call + if should_terminate: + # Yield tool results and return immediately without calling LLM again + yield ChatResponseUpdate(contents=function_call_results, role="tool") + return + if any( fcr.exception is not None for fcr in function_call_results @@ -1908,7 +2127,9 @@ async def streaming_function_invocation_wrapper( # Failsafe: give up on tools, ask model for plain answer kwargs["tool_choice"] = "none" - async for update in func(self, messages=prepped_messages, **kwargs): + # Filter out internal framework kwargs before passing to clients. + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"} + async for update in func(self, messages=prepped_messages, **filtered_kwargs): yield update return streaming_function_invocation_wrapper diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index f4662352a0..ab68382a83 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -925,6 +925,10 @@ class DataContent(BaseContent): image_data = b"raw image bytes" data_content = DataContent(data=image_data, media_type="image/png") + # Create from base64-encoded string + base64_string = "iVBORw0KGgoAAAANS..." + data_content = DataContent(data=base64_string, media_type="image/png") + # Create from data URI data_uri = "..." data_content = DataContent(uri=data_uri) @@ -986,11 +990,38 @@ def __init__( **kwargs: Any additional keyword arguments. """ + @overload + def __init__( + self, + *, + data: str, + media_type: str, + annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, + additional_properties: dict[str, Any] | None = None, + raw_representation: Any | None = None, + **kwargs: Any, + ) -> None: + """Initializes a DataContent instance with base64-encoded string data. + + Important: + This is for binary data that is represented as a data URI, not for online resources. + Use ``UriContent`` for online resources. + + Keyword Args: + data: The base64-encoded string data represented by this instance. + The data is used directly to construct a data URI. + media_type: The media type of the data. + annotations: Optional annotations associated with the content. + additional_properties: Optional additional properties associated with the content. + raw_representation: Optional raw representation of the content. + **kwargs: Any additional keyword arguments. + """ + def __init__( self, *, uri: str | None = None, - data: bytes | None = None, + data: bytes | str | None = None, media_type: str | None = None, annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, additional_properties: dict[str, Any] | None = None, @@ -1006,8 +1037,9 @@ def __init__( Keyword Args: uri: The URI of the data represented by this instance. Should be in the form: "data:{media_type};base64,{base64_data}". - data: The binary data represented by this instance. - The data is transformed into a base64-encoded data URI. + data: The binary data or base64-encoded string represented by this instance. + If bytes, the data is transformed into a base64-encoded data URI. + If str, it is assumed to be already base64-encoded and used directly. media_type: The media type of the data. annotations: Optional annotations associated with the content. additional_properties: Optional additional properties associated with the content. @@ -1017,7 +1049,9 @@ def __init__( if uri is None: if data is None or media_type is None: raise ValueError("Either 'data' and 'media_type' or 'uri' must be provided.") - uri = f"data:{media_type};base64,{base64.b64encode(data).decode('utf-8')}" + + base64_data: str = base64.b64encode(data).decode("utf-8") if isinstance(data, bytes) else data + uri = f"data:{media_type};base64,{base64_data}" # Validate URI format and extract media type if not provided validated_uri = self._validate_uri(uri) @@ -1816,13 +1850,14 @@ def prepare_function_call_results(content: Contents | Any | list[Contents | Any] """Prepare the values of the function call results.""" if isinstance(content, Contents): # For BaseContent objects, use to_dict and serialize to JSON - return json.dumps(content.to_dict(exclude={"raw_representation", "additional_properties"})) + # Use default=str to handle datetime and other non-JSON-serializable objects + return json.dumps(content.to_dict(exclude={"raw_representation", "additional_properties"}), default=str) dumpable = _prepare_function_call_results_as_dumpable(content) if isinstance(dumpable, str): return dumpable - # fallback - return json.dumps(dumpable) + # fallback - use default=str to handle datetime and other non-JSON-serializable objects + return json.dumps(dumpable, default=str) # region Chat Response constants diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 81fe1f3b73..d4f6c1411d 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -13,20 +13,25 @@ AgentRunResponseUpdate, AgentThread, BaseAgent, + BaseContent, ChatMessage, + Contents, FunctionApprovalRequestContent, FunctionApprovalResponseContent, FunctionCallContent, FunctionResultContent, Role, + TextContent, UsageDetails, ) from ..exceptions import AgentExecutionException +from ._checkpoint import CheckpointStorage from ._events import ( AgentRunUpdateEvent, RequestInfoEvent, WorkflowEvent, + WorkflowOutputEvent, ) from ._message_utils import normalize_messages_input from ._typing_utils import is_type_compatible @@ -117,6 +122,8 @@ async def run( messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, ) -> AgentRunResponse: """Get a response from the workflow agent (non-streaming). @@ -124,10 +131,16 @@ async def run( This method collects all streaming updates and merges them into a single response. Args: - messages: The message(s) to send to the workflow. + messages: The message(s) to send to the workflow. Required for new runs, + should be None when resuming from checkpoint. Keyword Args: thread: The conversation thread. If None, a new thread will be created. + checkpoint_id: ID of checkpoint to restore from. If provided, the workflow + resumes from this checkpoint instead of starting fresh. + checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id, + used to load and restore the checkpoint. When provided without checkpoint_id, + enables checkpointing for this run. **kwargs: Additional keyword arguments. Returns: @@ -139,7 +152,9 @@ async def run( thread = thread or self.get_new_thread() response_id = str(uuid.uuid4()) - async for update in self._run_stream_impl(input_messages, response_id): + async for update in self._run_stream_impl( + input_messages, response_id, thread, checkpoint_id, checkpoint_storage + ): response_updates.append(update) # Convert updates to final response. @@ -155,15 +170,23 @@ async def run_stream( messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, **kwargs: Any, ) -> AsyncIterable[AgentRunResponseUpdate]: """Stream response updates from the workflow agent. Args: - messages: The message(s) to send to the workflow. + messages: The message(s) to send to the workflow. Required for new runs, + should be None when resuming from checkpoint. Keyword Args: thread: The conversation thread. If None, a new thread will be created. + checkpoint_id: ID of checkpoint to restore from. If provided, the workflow + resumes from this checkpoint instead of starting fresh. + checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id, + used to load and restore the checkpoint. When provided without checkpoint_id, + enables checkpointing for this run. **kwargs: Additional keyword arguments. Yields: @@ -174,7 +197,9 @@ async def run_stream( response_updates: list[AgentRunResponseUpdate] = [] response_id = str(uuid.uuid4()) - async for update in self._run_stream_impl(input_messages, response_id): + async for update in self._run_stream_impl( + input_messages, response_id, thread, checkpoint_id, checkpoint_storage + ): response_updates.append(update) yield update @@ -188,12 +213,18 @@ async def _run_stream_impl( self, input_messages: list[ChatMessage], response_id: str, + thread: AgentThread, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, ) -> AsyncIterable[AgentRunResponseUpdate]: """Internal implementation of streaming execution. Args: input_messages: Normalized input messages to process. response_id: The unique response ID for this workflow execution. + thread: The conversation thread containing message history. + checkpoint_id: ID of checkpoint to restore from. + checkpoint_storage: Runtime checkpoint storage. Yields: AgentRunResponseUpdate objects representing the workflow execution progress. @@ -217,10 +248,27 @@ async def _run_stream_impl( # and we will let the workflow to handle this -- the agent does not # have an opinion on this. event_stream = self.workflow.send_responses_streaming(function_responses) + elif checkpoint_id is not None: + # Resume from checkpoint - don't prepend thread history since workflow state + # is being restored from the checkpoint + event_stream = self.workflow.run_stream( + message=None, + checkpoint_id=checkpoint_id, + checkpoint_storage=checkpoint_storage, + ) else: # Execute workflow with streaming (initial run or no function responses) - # Pass the new input messages directly to the workflow - event_stream = self.workflow.run_stream(input_messages) + # Build the complete conversation by prepending thread history to input messages + conversation_messages: list[ChatMessage] = [] + if thread.message_store: + history = await thread.message_store.list_messages() + if history: + conversation_messages.extend(history) + conversation_messages.extend(input_messages) + event_stream = self.workflow.run_stream( + message=conversation_messages, + checkpoint_storage=checkpoint_storage, + ) # Process events from the stream async for event in event_stream: @@ -236,9 +284,8 @@ def _convert_workflow_event_to_agent_update( ) -> AgentRunResponseUpdate | None: """Convert a workflow event to an AgentRunResponseUpdate. - Only AgentRunUpdateEvent and RequestInfoEvent are processed. - Other workflow events are ignored as they are workflow-internal and should - have corresponding AgentRunUpdateEvent emissions if relevant to agent consumers. + AgentRunUpdateEvent, RequestInfoEvent, and WorkflowOutputEvent are processed. + Other workflow events are ignored as they are workflow-internal. """ match event: case AgentRunUpdateEvent(data=update): @@ -247,6 +294,42 @@ def _convert_workflow_event_to_agent_update( return update return None + case WorkflowOutputEvent(data=data, source_executor_id=source_executor_id): + # Convert workflow output to an agent response update. + # Handle different data types appropriately. + if isinstance(data, AgentRunResponseUpdate): + # Already an update, pass through + return data + if isinstance(data, ChatMessage): + # Convert ChatMessage to update + return AgentRunResponseUpdate( + contents=list(data.contents), + role=data.role, + author_name=data.author_name or source_executor_id, + response_id=response_id, + message_id=str(uuid.uuid4()), + created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + raw_representation=data, + ) + # Determine contents based on data type + if isinstance(data, BaseContent): + # Already a content type (TextContent, ImageContent, etc.) + contents: list[Contents] = [cast(Contents, data)] + elif isinstance(data, str): + contents = [TextContent(text=data)] + else: + # Fallback: convert to string representation + contents = [TextContent(text=str(data))] + return AgentRunResponseUpdate( + contents=contents, + role=Role.ASSISTANT, + author_name=source_executor_id, + response_id=response_id, + message_id=str(uuid.uuid4()), + created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + raw_representation=data, + ) + case RequestInfoEvent(request_id=request_id): # Store the pending request for later correlation self.pending_requests[request_id] = event diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 358cee94dd..26300ad473 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -11,6 +11,7 @@ from .._threads import AgentThread from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from ._const import WORKFLOW_RUN_KWARGS_KEY from ._conversation_state import encode_chat_messages from ._events import ( AgentRunEvent, @@ -105,6 +106,11 @@ def workflow_output_types(self) -> list[type[Any]]: return [AgentRunResponse] return [] + @property + def description(self) -> str | None: + """Get the description of the underlying agent.""" + return self._agent.description + @handler async def run( self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse] @@ -304,9 +310,12 @@ async def _run_agent(self, ctx: WorkflowContext) -> AgentRunResponse | None: Returns: The complete AgentRunResponse, or None if waiting for user input. """ + run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + response = await self._agent.run( self._cache, thread=self._agent_thread, + **run_kwargs, ) await ctx.add_event(AgentRunEvent(self.id, response)) @@ -328,11 +337,14 @@ async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentRunResponse | Returns: The complete AgentRunResponse, or None if waiting for user input. """ + run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + updates: list[AgentRunResponseUpdate] = [] user_input_requests: list[FunctionApprovalRequestContent] = [] async for update in self._agent.run_stream( self._cache, thread=self._agent_thread, + **run_kwargs, ): updates.append(update) await ctx.add_event(AgentRunUpdateEvent(self.id, update)) diff --git a/python/packages/core/agent_framework/_workflows/_concurrent.py b/python/packages/core/agent_framework/_workflows/_concurrent.py index f6a7b09e60..2900254126 100644 --- a/python/packages/core/agent_framework/_workflows/_concurrent.py +++ b/python/packages/core/agent_framework/_workflows/_concurrent.py @@ -29,7 +29,8 @@ - a default aggregator that combines all agent conversations and completes the workflow Notes: -- Participants should be AgentProtocol instances or Executors. +- Participants can be provided as AgentProtocol or Executor instances via `.participants()`, + or as factories returning AgentProtocol or Executor via `.register_participants()`. - A custom aggregator can be provided as: - an Executor instance (it should handle list[AgentExecutorResponse], yield output), or @@ -189,8 +190,11 @@ class ConcurrentBuilder: r"""High-level builder for concurrent agent workflows. - `participants([...])` accepts a list of AgentProtocol (recommended) or Executor. + - `register_participants([...])` accepts a list of factories for AgentProtocol (recommended) + or Executor factories - `build()` wires: dispatcher -> fan-out -> participants -> fan-in -> aggregator. - - `with_custom_aggregator(...)` overrides the default aggregator with an Executor or callback. + - `with_aggregator(...)` overrides the default aggregator with an Executor or callback. + - `register_aggregator(...)` accepts a factory for an Executor as custom aggregator. Usage: @@ -201,14 +205,33 @@ class ConcurrentBuilder: # Minimal: use default aggregator (returns list[ChatMessage]) workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).build() + # With agent factories + workflow = ConcurrentBuilder().register_participants([create_agent1, create_agent2, create_agent3]).build() + # Custom aggregator via callback (sync or async). The callback receives # list[AgentExecutorResponse] and its return value becomes the workflow's output. - def summarize(results): + def summarize(results: list[AgentExecutorResponse]) -> str: return " | ".join(r.agent_run_response.messages[-1].text for r in results) - workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_custom_aggregator(summarize).build() + workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_aggregator(summarize).build() + + + # Custom aggregator via a factory + class MyAggregator(Executor): + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results)) + + + workflow = ( + ConcurrentBuilder() + .register_participants([create_agent1, create_agent2, create_agent3]) + .register_aggregator(lambda: MyAggregator(id="my_aggregator")) + .build() + ) + # Enable checkpoint persistence so runs can resume workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_checkpointing(storage).build() @@ -219,10 +242,67 @@ def summarize(results): def __init__(self) -> None: self._participants: list[AgentProtocol | Executor] = [] + self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = [] self._aggregator: Executor | None = None + self._aggregator_factory: Callable[[], Executor] | None = None self._checkpoint_storage: CheckpointStorage | None = None self._request_info_enabled: bool = False + def register_participants( + self, + participant_factories: Sequence[Callable[[], AgentProtocol | Executor]], + ) -> "ConcurrentBuilder": + r"""Define the parallel participants for this concurrent workflow. + + Accepts factories (callables) that return AgentProtocol instances (e.g., created + by a chat client) or Executor instances. Each participant created by a factory + is wired as a parallel branch using fan-out edges from an internal dispatcher. + + Args: + participant_factories: Sequence of callables returning AgentProtocol or Executor instances + + Raises: + ValueError: if `participant_factories` is empty or `.participants()` + or `.register_participants()` were already called + + Example: + + .. code-block:: python + + def create_researcher() -> ChatAgent: + return ... + + + def create_marketer() -> ChatAgent: + return ... + + + def create_legal() -> ChatAgent: + return ... + + + class MyCustomExecutor(Executor): ... + + + wf = ConcurrentBuilder().register_participants([create_researcher, create_marketer, create_legal]).build() + + # Mixing agent(s) and executor(s) is supported + wf2 = ConcurrentBuilder().register_participants([create_researcher, MyCustomExecutor]).build() + """ + if self._participants: + raise ValueError( + "Cannot mix .participants([...]) and .register_participants() in the same builder instance." + ) + + if self._participant_factories: + raise ValueError("register_participants() has already been called on this builder instance.") + + if not participant_factories: + raise ValueError("participant_factories cannot be empty") + + self._participant_factories = list(participant_factories) + return self + def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "ConcurrentBuilder": r"""Define the parallel participants for this concurrent workflow. @@ -230,8 +310,12 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Con instances. Each participant is wired as a parallel branch using fan-out edges from an internal dispatcher. + Args: + participants: Sequence of AgentProtocol or Executor instances + Raises: - ValueError: if `participants` is empty or contains duplicates + ValueError: if `participants` is empty, contains duplicates, or `.register_participants()` + or `.participants()` were already called TypeError: if any entry is not AgentProtocol or Executor Example: @@ -243,6 +327,14 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Con # Mixing agent(s) and executor(s) is supported wf2 = ConcurrentBuilder().participants([researcher_agent, my_custom_executor]).build() """ + if self._participant_factories: + raise ValueError( + "Cannot mix .participants([...]) and .register_participants() in the same builder instance." + ) + + if self._participants: + raise ValueError("participants() has already been called on this builder instance.") + if not participants: raise ValueError("participants cannot be empty") @@ -265,38 +357,107 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Con self._participants = list(participants) return self - def with_aggregator(self, aggregator: Executor | Callable[..., Any]) -> "ConcurrentBuilder": - r"""Override the default aggregator with an Executor or a callback. + def register_aggregator(self, aggregator_factory: Callable[[], Executor]) -> "ConcurrentBuilder": + r"""Define a custom aggregator for this concurrent workflow. + + Accepts a factory (callable) that returns an Executor instance. The executor + should handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)`. + + Args: + aggregator_factory: Callable that returns an Executor instance + + Example: + .. code-block:: python + + class MyCustomExecutor(Executor): ... + + + wf = ( + ConcurrentBuilder() + .register_participants([create_researcher, create_marketer, create_legal]) + .register_aggregator(lambda: MyCustomExecutor(id="my_aggregator")) + .build() + ) + """ + if self._aggregator is not None: + raise ValueError( + "Cannot mix .with_aggregator(...) and .register_aggregator(...) in the same builder instance." + ) + + if self._aggregator_factory is not None: + raise ValueError("register_aggregator() has already been called on this builder instance.") + + self._aggregator_factory = aggregator_factory + return self + + def with_aggregator( + self, + aggregator: Executor + | Callable[[list[AgentExecutorResponse]], Any] + | Callable[[list[AgentExecutorResponse], WorkflowContext[Never, Any]], Any], + ) -> "ConcurrentBuilder": + r"""Override the default aggregator with an executor or a callback. - - Executor: must handle `list[AgentExecutorResponse]` and - yield output using `ctx.yield_output(...)` and add a - output and the workflow becomes idle. + - Executor: must handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)` - Callback: sync or async callable with one of the signatures: `(results: list[AgentExecutorResponse]) -> Any | None` or `(results: list[AgentExecutorResponse], ctx: WorkflowContext) -> Any | None`. If the callback returns a non-None value, it becomes the workflow's output. + Args: + aggregator: Executor instance, or callback function + Example: .. code-block:: python + # Executor-based aggregator + class CustomAggregator(Executor): + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext) -> None: + await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results)) + + + wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(CustomAggregator()).build() + # Callback-based aggregator (string result) - async def summarize(results): + async def summarize(results: list[AgentExecutorResponse]) -> str: return " | ".join(r.agent_run_response.messages[-1].text for r in results) - wf = ConcurrentBuilder().participants([a1, a2, a3]).with_custom_aggregator(summarize).build() + wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(summarize).build() + + + # Callback-based aggregator (yield result) + async def summarize(results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results)) + + + wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(summarize).build() """ + if self._aggregator_factory is not None: + raise ValueError( + "Cannot mix .with_aggregator(...) and .register_aggregator(...) in the same builder instance." + ) + + if self._aggregator is not None: + raise ValueError("with_aggregator() has already been called on this builder instance.") + if isinstance(aggregator, Executor): self._aggregator = aggregator elif callable(aggregator): self._aggregator = _CallbackAggregator(aggregator) else: raise TypeError("aggregator must be an Executor or a callable") + return self def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "ConcurrentBuilder": - """Enable checkpoint persistence using the provided storage backend.""" + """Enable checkpoint persistence using the provided storage backend. + + Args: + checkpoint_storage: CheckpointStorage instance for persisting workflow state + """ self._checkpoint_storage = checkpoint_storage return self @@ -329,7 +490,7 @@ def build(self) -> Workflow: before sending the outputs to the aggregator - Aggregator yields output and the workflow becomes idle. The output is either: - list[ChatMessage] (default aggregator: one user + one assistant per agent) - - custom payload from the provided callback/executor + - custom payload from the provided aggregator Returns: Workflow: a ready-to-run workflow instance @@ -343,25 +504,46 @@ def build(self) -> Workflow: workflow = ConcurrentBuilder().participants([agent1, agent2]).build() """ - if not self._participants: - raise ValueError("No participants provided. Call .participants([...]) first.") + if not self._participants and not self._participant_factories: + raise ValueError( + "No participants provided. Call .participants([...]) or .register_participants([...]) first." + ) + # Internal nodes dispatcher = _DispatchToAllParticipants(id="dispatcher") - aggregator = self._aggregator or _AggregateAgentConversations(id="aggregator") + aggregator = ( + self._aggregator + if self._aggregator is not None + else ( + self._aggregator_factory() + if self._aggregator_factory is not None + else _AggregateAgentConversations(id="aggregator") + ) + ) + + participants: list[Executor | AgentProtocol] = [] + if self._participant_factories: + # Resolve the participant factories now. This doesn't break the factory pattern + # since the Concurrent builder still creates new instances per workflow build. + for factory in self._participant_factories: + p = factory() + participants.append(p) + else: + participants = self._participants builder = WorkflowBuilder() builder.set_start_executor(dispatcher) - builder.add_fan_out_edges(dispatcher, list(self._participants)) + builder.add_fan_out_edges(dispatcher, participants) if self._request_info_enabled: # Insert interceptor between fan-in and aggregator # participants -> fan-in -> interceptor -> aggregator request_info_interceptor = RequestInfoInterceptor(executor_id="request_info") - builder.add_fan_in_edges(list(self._participants), request_info_interceptor) + builder.add_fan_in_edges(participants, request_info_interceptor) builder.add_edge(request_info_interceptor, aggregator) else: # Direct fan-in to aggregator - builder.add_fan_in_edges(list(self._participants), aggregator) + builder.add_fan_in_edges(participants, aggregator) if self._checkpoint_storage is not None: builder = builder.with_checkpointing(self._checkpoint_storage) diff --git a/python/packages/core/agent_framework/_workflows/_const.py b/python/packages/core/agent_framework/_workflows/_const.py index 6247be338a..34bde1da47 100644 --- a/python/packages/core/agent_framework/_workflows/_const.py +++ b/python/packages/core/agent_framework/_workflows/_const.py @@ -9,6 +9,11 @@ # Source identifier for internal workflow messages. INTERNAL_SOURCE_PREFIX = "internal" +# SharedState key for storing run kwargs that should be passed to agent invocations. +# Used by all orchestration patterns (Sequential, Concurrent, GroupChat, Handoff, Magentic) +# to pass kwargs from workflow.run_stream() through to agent.run_stream() and @ai_function tools. +WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs" + def INTERNAL_SOURCE_ID(executor_id: str) -> str: """Generate an internal source ID for a given executor.""" diff --git a/python/packages/core/agent_framework/_workflows/_group_chat.py b/python/packages/core/agent_framework/_workflows/_group_chat.py index eab720a4b1..725a5c829c 100644 --- a/python/packages/core/agent_framework/_workflows/_group_chat.py +++ b/python/packages/core/agent_framework/_workflows/_group_chat.py @@ -132,7 +132,11 @@ class ManagerSelectionResponse(BaseModel): final_message: Optional final message string when finishing conversation (will be converted to ChatMessage) """ - model_config = {"extra": "forbid"} + model_config = { + "extra": "forbid", + # OpenAI strict mode requires all properties to be in required array + "json_schema_extra": {"required": ["selected_participant", "instruction", "finish", "final_message"]}, + } selected_participant: str | None = None instruction: str | None = None diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 8e0a7aec1e..9a99657902 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -14,7 +14,10 @@ Key properties: - The entire conversation is maintained and reused on every hop - The coordinator signals a handoff by invoking a tool call that names the specialist -- After a specialist responds, the workflow immediately requests new user input +- In human_in_loop mode (default), the workflow requests user input after each agent response + that doesn't trigger a handoff +- In autonomous mode, agents continue responding until they invoke a handoff tool or reach + a termination condition or turn limit """ import logging @@ -76,9 +79,9 @@ def _create_handoff_tool(alias: str, description: str | None = None) -> AIFuncti # Note: approval_mode is intentionally NOT set for handoff tools. # Handoff tools are framework-internal signals that trigger routing logic, - # not actual function executions. They are automatically intercepted and - # never actually execute, so approval is unnecessary and causes issues - # with tool_calls/responses pairing when cleaning conversations. + # not actual function executions. They are automatically intercepted by + # _AutoHandoffMiddleware which short-circuits execution and provides synthetic + # results, so the function body never actually runs in practice. @ai_function(name=tool_name, description=doc) def _handoff_tool(context: str | None = None) -> str: """Return a deterministic acknowledgement that encodes the target alias.""" @@ -109,6 +112,8 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent: chat_message_store_factory=agent.chat_message_store_factory, context_providers=agent.context_provider, middleware=middleware, + # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. + allow_multiple_tool_calls=False, frequency_penalty=options.frequency_penalty, logit_bias=dict(options.logit_bias) if options.logit_bias else None, max_tokens=options.max_tokens, @@ -130,19 +135,57 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent: @dataclass class HandoffUserInputRequest: - """Request message emitted when the workflow needs fresh user input.""" + """Request message emitted when the workflow needs fresh user input. + + Note: The conversation field is intentionally excluded from checkpoint serialization + to prevent duplication. The conversation is preserved in the coordinator's state + and will be reconstructed on restore. See issue #2667. + """ conversation: list[ChatMessage] awaiting_agent_id: str prompt: str source_executor_id: str + def to_dict(self) -> dict[str, Any]: + """Serialize to dict, excluding conversation to prevent checkpoint duplication. + + The conversation is already preserved in the workflow coordinator's state. + Including it here would cause duplicate messages when restoring from checkpoint. + """ + return { + "awaiting_agent_id": self.awaiting_agent_id, + "prompt": self.prompt, + "source_executor_id": self.source_executor_id, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "HandoffUserInputRequest": + """Deserialize from dict, initializing conversation as empty. + + The conversation will be reconstructed from the coordinator's state on restore. + """ + return cls( + conversation=[], + awaiting_agent_id=data["awaiting_agent_id"], + prompt=data["prompt"], + source_executor_id=data["source_executor_id"], + ) + @dataclass class _ConversationWithUserInput: - """Internal message carrying full conversation + new user messages from gateway to coordinator.""" + """Internal message carrying full conversation + new user messages from gateway to coordinator. + + Attributes: + full_conversation: The conversation messages to process. + is_post_restore: If True, indicates this message was created after a checkpoint restore. + The coordinator should append these messages to its existing conversation rather + than replacing it. This prevents duplicate messages (see issue #2667). + """ full_conversation: list[ChatMessage] = field(default_factory=lambda: []) # type: ignore[misc] + is_post_restore: bool = False @dataclass @@ -179,7 +222,7 @@ async def process( class _InputToConversation(Executor): - """Normalises initial workflow input into a list[ChatMessage].""" + """Normalizes initial workflow input into a list[ChatMessage].""" @handler async def from_str(self, prompt: str, ctx: WorkflowContext[list[ChatMessage]]) -> None: @@ -187,16 +230,12 @@ async def from_str(self, prompt: str, ctx: WorkflowContext[list[ChatMessage]]) - await ctx.send_message([ChatMessage(Role.USER, text=prompt)]) @handler - async def from_message(self, message: ChatMessage, ctx: WorkflowContext[list[ChatMessage]]) -> None: # type: ignore[name-defined] + async def from_message(self, message: ChatMessage, ctx: WorkflowContext[list[ChatMessage]]) -> None: """Pass through an existing chat message as the initial conversation.""" await ctx.send_message([message]) @handler - async def from_messages( - self, - messages: list[ChatMessage], - ctx: WorkflowContext[list[ChatMessage]], - ) -> None: # type: ignore[name-defined] + async def from_messages(self, messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: """Forward a list of chat messages as the starting conversation history.""" await ctx.send_message(list(messages)) @@ -362,7 +401,8 @@ async def handle_agent_response( self._conversation = list(full_conv) else: # Subsequent responses - append only new messages from this agent - # Keep ALL messages including tool calls to maintain complete history + # Keep ALL messages including tool calls to maintain complete history. + # This includes assistant messages with function calls and tool role messages with results. new_messages = response.agent_run_response.messages or [] self._conversation.extend(new_messages) @@ -439,9 +479,25 @@ async def handle_user_input( message: _ConversationWithUserInput, ctx: WorkflowContext[AgentExecutorRequest, list[ChatMessage]], ) -> None: - """Receive full conversation with new user input from gateway, update history, trim for agent.""" - # Update authoritative conversation - self._conversation = list(message.full_conversation) + """Receive user input from gateway, update history, and route to agent. + + The message.full_conversation may contain: + - Full conversation history + new user messages (normal flow) + - Only new user messages (post-checkpoint-restore flow, see issue #2667) + + The gateway sets message.is_post_restore=True when resuming after a checkpoint + restore. In that case, we append the new messages to the existing conversation + rather than replacing it. + """ + incoming = message.full_conversation + + if message.is_post_restore and self._conversation: + # Post-restore: append new user messages to existing conversation + # The coordinator already has its conversation restored from checkpoint + self._conversation.extend(incoming) + else: + # Normal flow: replace with full conversation + self._conversation = list(incoming) if incoming else self._conversation # Reset autonomous turn counter on new user input self._autonomous_turns = 0 @@ -462,9 +518,9 @@ async def handle_user_input( ) else: logger.info(f"Routing user input to coordinator '{target_agent_id}'") - # Note: Stack is only used for specialist-to-specialist handoffs, not user input routing - # Clean before sending to target agent + # Clean conversation before sending to target agent + # Removes tool-related messages that shouldn't be resent on every turn cleaned = clean_conversation_for_handoff(self._conversation) request = AgentExecutorRequest(messages=cleaned, should_respond=True) await ctx.send_message(request, target_id=target_agent_id) @@ -581,13 +637,7 @@ def _apply_response_metadata(self, conversation: list[ChatMessage], agent_respon class _UserInputGateway(Executor): """Bridges conversation context with the request & response cycle and re-enters the loop.""" - def __init__( - self, - *, - starting_agent_id: str, - prompt: str | None, - id: str, - ) -> None: + def __init__(self, *, starting_agent_id: str, prompt: str | None, id: str) -> None: """Initialise the gateway that requests user input and forwards responses.""" super().__init__(id) self._starting_agent_id = starting_agent_id @@ -626,20 +676,39 @@ async def resume_from_user( response: object, ctx: WorkflowContext[_ConversationWithUserInput], ) -> None: - """Convert user input responses back into chat messages and resume the workflow.""" - # Reconstruct full conversation with new user input - conversation = list(original_request.conversation) + """Convert user input responses back into chat messages and resume the workflow. + + After checkpoint restore, original_request.conversation will be empty (not serialized + to prevent duplication - see issue #2667). In this case, we send only the new user + messages and let the coordinator append them to its already-restored conversation. + """ user_messages = _as_user_messages(response) - conversation.extend(user_messages) - # Send full conversation back to coordinator (not trimmed) - # Coordinator will update its authoritative history and trim for agent - message = _ConversationWithUserInput(full_conversation=conversation) + if original_request.conversation: + # Normal flow: have conversation history from the original request + conversation = list(original_request.conversation) + conversation.extend(user_messages) + message = _ConversationWithUserInput(full_conversation=conversation, is_post_restore=False) + else: + # Post-restore flow: conversation was not serialized, send only new user messages + # The coordinator will append these to its already-restored conversation + message = _ConversationWithUserInput(full_conversation=user_messages, is_post_restore=True) + await ctx.send_message(message, target_id="handoff-coordinator") def _as_user_messages(payload: Any) -> list[ChatMessage]: - """Normalise arbitrary payloads into user-authored chat messages.""" + """Normalize arbitrary payloads into user-authored chat messages. + + Handles various input formats: + - ChatMessage instances (converted to USER role if needed) + - List of ChatMessage instances + - Mapping with 'text' or 'content' key + - Any other type (converted to string) + + Returns: + List of ChatMessage instances with USER role. + """ if isinstance(payload, ChatMessage): if payload.role == Role.USER: return [payload] @@ -735,7 +804,7 @@ class HandoffBuilder: name="customer_support", participants=[coordinator, refund, shipping], ) - .set_coordinator("coordinator_agent") + .set_coordinator(coordinator) .build() ) @@ -754,7 +823,7 @@ class HandoffBuilder: # Enable specialist-to-specialist handoffs with fluent API workflow = ( HandoffBuilder(participants=[coordinator, replacement, delivery, billing]) - .set_coordinator("coordinator_agent") + .set_coordinator(coordinator) .add_handoff(coordinator, [replacement, delivery, billing]) # Coordinator routes to all .add_handoff(replacement, [delivery, billing]) # Replacement delegates to delivery/billing .add_handoff(delivery, billing) # Delivery escalates to billing @@ -764,6 +833,35 @@ class HandoffBuilder: # Flow: User → Coordinator → Replacement → Delivery → Back to User # (Replacement hands off to Delivery without returning to user) + **Use Participant Factories for State Isolation:** + + .. code-block:: python + # Define factories that produce fresh agent instances per workflow run + def create_coordinator() -> AgentProtocol: + return chat_client.create_agent( + instructions="You are the coordinator agent...", + name="coordinator_agent", + ) + + + def create_specialist() -> AgentProtocol: + return chat_client.create_agent( + instructions="You are the specialist agent...", + name="specialist_agent", + ) + + + workflow = ( + HandoffBuilder( + participant_factories={ + "coordinator": create_coordinator, + "specialist": create_specialist, + } + ) + .set_coordinator("coordinator") + .build() + ) + **Custom Termination Condition:** .. code-block:: python @@ -771,7 +869,7 @@ class HandoffBuilder: # Terminate when user says goodbye or after 5 exchanges workflow = ( HandoffBuilder(participants=[coordinator, refund, shipping]) - .set_coordinator("coordinator_agent") + .set_coordinator(coordinator) .with_termination_condition( lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 5 or any("goodbye" in msg.text.lower() for msg in conv[-2:]) @@ -788,7 +886,7 @@ class HandoffBuilder: storage = InMemoryCheckpointStorage() workflow = ( HandoffBuilder(participants=[coordinator, refund, shipping]) - .set_coordinator("coordinator_agent") + .set_coordinator(coordinator) .with_checkpointing(storage) .build() ) @@ -797,6 +895,9 @@ class HandoffBuilder: name: Optional workflow name for identification and logging. participants: List of agents (AgentProtocol) or executors to participate in the handoff. The first agent you specify as coordinator becomes the orchestrating agent. + participant_factories: Mapping of factory names to callables that produce agents or + executors when invoked. This allows for lazy instantiation + and state isolation per workflow instance created by this builder. description: Optional human-readable description of the workflow. Raises: @@ -809,14 +910,16 @@ def __init__( *, name: str | None = None, participants: Sequence[AgentProtocol | Executor] | None = None, + participant_factories: Mapping[str, Callable[[], AgentProtocol | Executor]] | None = None, description: str | None = None, ) -> None: r"""Initialize a HandoffBuilder for creating conversational handoff workflows. The builder starts in an unconfigured state and requires you to call: 1. `.participants([...])` - Register agents - 2. `.set_coordinator(...)` - Designate which agent receives initial user input - 3. `.build()` - Construct the final Workflow + 2. or `.participant_factories({...})` - Register agent/executor factories + 3. `.set_coordinator(...)` - Designate which agent receives initial user input + 4. `.build()` - Construct the final Workflow Optional configuration methods allow you to customize context management, termination logic, and persistence. @@ -828,6 +931,9 @@ def __init__( participate in the handoff workflow. You can also call `.participants([...])` later. Each participant must have a unique identifier (name for agents, id for executors). + participant_factories: Optional mapping of factory names to callables that produce agents or + executors when invoked. This allows for lazy instantiation + and state isolation per workflow instance created by this builder. description: Optional human-readable description explaining the workflow's purpose. Useful for documentation and observability. @@ -848,7 +954,6 @@ def __init__( self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] = ( _default_termination_condition ) - self._auto_register_handoff_tools: bool = True self._handoff_config: dict[str, list[str]] = {} # Maps agent_id -> [target_agent_ids] self._return_to_previous: bool = False self._interaction_mode: Literal["human_in_loop", "autonomous"] = "human_in_loop" @@ -856,9 +961,79 @@ def __init__( self._request_info_enabled: bool = False self._request_info_filter: set[str] | None = None + self._participant_factories: dict[str, Callable[[], AgentProtocol | Executor]] = {} + if participant_factories: + self.participant_factories(participant_factories) + if participants: self.participants(participants) + # region Fluent Configuration Methods + + def participant_factories( + self, participant_factories: Mapping[str, Callable[[], AgentProtocol | Executor]] + ) -> "HandoffBuilder": + """Register factories that produce agents or executors for the handoff workflow. + + Each factory is a callable that returns an AgentProtocol or Executor instance. + Factories are invoked when building the workflow, allowing for lazy instantiation + and state isolation per workflow instance. + + Args: + participant_factories: Mapping of factory names to callables that return AgentProtocol or Executor + instances. Each produced participant must have a unique identifier (name for + agents, id for executors). + + Returns: + Self for method chaining. + + Raises: + ValueError: If participant_factories is empty or `.participants(...)` or `.participant_factories(...)` + has already been called. + + Example: + .. code-block:: python + + from agent_framework import ChatAgent, HandoffBuilder + + + def create_coordinator() -> ChatAgent: + return ... + + + def create_refund_agent() -> ChatAgent: + return ... + + + def create_billing_agent() -> ChatAgent: + return ... + + + factories = { + "coordinator": create_coordinator, + "refund": create_refund_agent, + "billing": create_billing_agent, + } + + builder = HandoffBuilder().participant_factories(factories) + # Use the factory IDs to create handoffs and set the coordinator + builder.add_handoff("coordinator", ["refund", "billing"]) + builder.set_coordinator("coordinator") + """ + if self._executors: + raise ValueError( + "Cannot mix .participants([...]) and .participant_factories() in the same builder instance." + ) + + if self._participant_factories: + raise ValueError("participant_factories() has already been called on this builder instance.") + + if not participant_factories: + raise ValueError("participant_factories cannot be empty") + + self._participant_factories = dict(participant_factories) + return self + def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "HandoffBuilder": """Register the agents or executors that will participate in the handoff workflow. @@ -875,7 +1050,8 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han Self for method chaining. Raises: - ValueError: If participants is empty or contains duplicates. + ValueError: If participants is empty, contains duplicates, or `.participants(...)` or + `.participant_factories(...)` has already been called. TypeError: If participants are not AgentProtocol or Executor instances. Example: @@ -897,26 +1073,28 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han This method resets any previously configured coordinator, so you must call `.set_coordinator(...)` again after changing participants. """ + if self._participant_factories: + raise ValueError( + "Cannot mix .participants([...]) and .participant_factories() in the same builder instance." + ) + + if self._executors: + raise ValueError("participants have already been assigned") + if not participants: raise ValueError("participants cannot be empty") named: dict[str, AgentProtocol | Executor] = {} for participant in participants: - identifier: str if isinstance(participant, Executor): identifier = participant.id elif isinstance(participant, AgentProtocol): - name_attr = getattr(participant, "name", None) - if not name_attr: - raise ValueError( - "Agents used in handoff workflows must have a stable name " - "so they can be addressed during routing." - ) - identifier = str(name_attr) + identifier = participant.display_name else: raise TypeError( f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." ) + if identifier in named: raise ValueError(f"Duplicate participant name '{identifier}' detected") named[identifier] = participant @@ -927,15 +1105,10 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han ) wrapped = metadata["executors"] - seen_ids: set[str] = set() - for executor in wrapped.values(): - if executor.id in seen_ids: - raise ValueError(f"Duplicate participant with id '{executor.id}' detected") - seen_ids.add(executor.id) - self._executors = {executor.id: executor for executor in wrapped.values()} self._aliases = metadata["aliases"] self._starting_agent_id = None + return self def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuilder": @@ -952,7 +1125,7 @@ def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuil Args: agent: The agent to use as the coordinator. Can be: - - Agent name (str): e.g., "coordinator_agent" + - Factory name (str): If using participant factories - AgentProtocol instance: The actual agent object - Executor instance: A custom executor wrapping an agent @@ -960,15 +1133,26 @@ def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuil Self for method chaining. Raises: - ValueError: If participants(...) hasn't been called yet, or if the specified - agent is not in the participants list. + ValueError: 1) If `agent` is an AgentProtocol or Executor instance but `.participants(...)` hasn't + been called yet, or if it is not in the participants list. + 2) If `agent` is a factory name (str) but `.participant_factories(...)` hasn't been + called yet, or if it is not in the participant_factories list. + TypeError: If `agent` is not a str, AgentProtocol, or Executor instance. Example: .. code-block:: python - # Use agent name - builder = HandoffBuilder().participants([coordinator, refund, billing]).set_coordinator("coordinator") + # Use factory name with `.participant_factories()` + builder = ( + HandoffBuilder() + .participant_factories({ + "coordinator": create_coordinator, + "refund": create_refund_agent, + "billing": create_billing_agent, + }) + .set_coordinator("coordinator") + ) # Or pass the agent object directly builder = HandoffBuilder().participants([coordinator, refund, billing]).set_coordinator(coordinator) @@ -979,12 +1163,29 @@ def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuil Decorate the tool with `approval_mode="always_require"` to ensure the workflow intercepts the call before execution and can make the transition. """ - if not self._executors: - raise ValueError("Call participants(...) before coordinator(...)") - resolved = self._resolve_to_id(agent) - if resolved not in self._executors: - raise ValueError(f"coordinator '{resolved}' is not part of the participants list") - self._starting_agent_id = resolved + if isinstance(agent, (AgentProtocol, Executor)): + if not self._executors: + raise ValueError( + "Call participants(...) before coordinator(...). If using participant_factories, " + "pass the factory name (str) instead of the agent instance." + ) + resolved = self._resolve_to_id(agent) + if resolved not in self._executors: + raise ValueError(f"coordinator '{resolved}' is not part of the participants list") + self._starting_agent_id = resolved + elif isinstance(agent, str): + if agent not in self._participant_factories: + raise ValueError( + f"coordinator factory name '{agent}' is not part of the participant_factories list. If " + "you are using participant instances, call .participants(...) and pass the agent instance instead." + ) + self._starting_agent_id = agent + else: + raise TypeError( + "coordinator must be a factory name (str), AgentProtocol, or Executor instance. " + f"Got {type(agent).__name__}." + ) + return self def add_handoff( @@ -1004,33 +1205,42 @@ def add_handoff( Args: source: The agent that can initiate the handoff. Can be: - - Agent name (str): e.g., "triage_agent" + - Factory name (str): If using participant factories - AgentProtocol instance: The actual agent object - Executor instance: A custom executor wrapping an agent + - Cannot mix factory names and instances across source and targets targets: One or more target agents that the source can hand off to. Can be: - - Single agent: "billing_agent" or agent_instance - - Multiple agents: ["billing_agent", "support_agent"] or [agent1, agent2] - tool_name: Optional custom name for the handoff tool. If not provided, generates - "handoff_to_" for single targets or "handoff_to__agent" - for multiple targets based on target names. - tool_description: Optional custom description for the handoff tool. If not provided, - generates "Handoff to the agent." + - Factory name (str): If using participant factories + - AgentProtocol instance: The actual agent object + - Executor instance: A custom executor wrapping an agent + - Single target: "billing_agent" or agent_instance + - Multiple targets: ["billing_agent", "support_agent"] or [agent1, agent2] + - Cannot mix factory names and instances across source and targets + tool_name: Optional custom name for the handoff tool. Currently not used in the + implementation - tools are always auto-generated as "handoff_to_". + Reserved for future enhancement. + tool_description: Optional custom description for the handoff tool. Currently not used + in the implementation - descriptions are always auto-generated as + "Handoff to the agent.". Reserved for future enhancement. Returns: Self for method chaining. Raises: - ValueError: If source or targets are not in the participants list, or if + ValueError: 1) If source or targets are not in the participants list, or if participants(...) hasn't been called yet. + 2) If source or targets are factory names (str) but participant_factories(...) + hasn't been called yet, or if they are not in the participant_factories list. + TypeError: If mixing factory names (str) and AgentProtocol/Executor instances Examples: - Single target: + Single target (using factory name): .. code-block:: python builder.add_handoff("triage_agent", "billing_agent") - Multiple targets (using agent names): + Multiple targets (using factory names): .. code-block:: python @@ -1055,138 +1265,70 @@ def add_handoff( .build() ) - Custom tool names and descriptions: - - .. code-block:: python - - builder.add_handoff( - "support_agent", - "escalation_agent", - tool_name="escalate_to_l2", - tool_description="Escalate this issue to Level 2 support", - ) - Note: - Handoff tools are automatically registered for each source agent - If a source agent is configured multiple times via add_handoff, targets are merged """ - if not self._executors: - raise ValueError("Call participants(...) before add_handoff(...)") - - # Resolve source agent ID - source_id = self._resolve_to_id(source) - if source_id not in self._executors: - raise ValueError(f"Source agent '{source}' is not in the participants list") - - # Normalize targets to list - target_list = [targets] if isinstance(targets, (str, AgentProtocol, Executor)) else list(targets) - - # Resolve all target IDs - target_ids: list[str] = [] - for target in target_list: - target_id = self._resolve_to_id(target) - if target_id not in self._executors: - raise ValueError(f"Target agent '{target}' is not in the participants list") - target_ids.append(target_id) - - # Merge with existing handoff configuration for this source - if source_id in self._handoff_config: - # Add new targets to existing list, avoiding duplicates - existing = self._handoff_config[source_id] - for target_id in target_ids: - if target_id not in existing: - existing.append(target_id) - else: - self._handoff_config[source_id] = target_ids - - return self - - def auto_register_handoff_tools(self, enabled: bool) -> "HandoffBuilder": - """Configure whether the builder should synthesize handoff tools for the starting agent.""" - self._auto_register_handoff_tools = enabled - return self - - def _apply_auto_tools(self, agent: ChatAgent, specialists: Mapping[str, Executor]) -> dict[str, str]: - """Attach synthetic handoff tools to a chat agent and return the target lookup table.""" - chat_options = agent.chat_options - existing_tools = list(chat_options.tools or []) - existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")} - - tool_targets: dict[str, str] = {} - new_tools: list[Any] = [] - for exec_id in specialists: - alias = exec_id - sanitized = sanitize_identifier(alias) - tool = _create_handoff_tool(alias) - if tool.name not in existing_names: - new_tools.append(tool) - tool_targets[tool.name.lower()] = exec_id - tool_targets[sanitized] = exec_id - tool_targets[alias.lower()] = exec_id - - if new_tools: - chat_options.tools = existing_tools + new_tools - else: - chat_options.tools = existing_tools - - return tool_targets - - def _resolve_agent_id(self, agent_identifier: str) -> str: - """Resolve an agent identifier to an executor ID. - - Args: - agent_identifier: Can be agent name, display name, or executor ID - - Returns: - The executor ID - - Raises: - ValueError: If the identifier cannot be resolved - """ - # Check if it's already an executor ID - if agent_identifier in self._executors: - return agent_identifier - - # Check if it's an alias - if agent_identifier in self._aliases: - return self._aliases[agent_identifier] - - # Not found - raise ValueError(f"Agent identifier '{agent_identifier}' not found in participants") - - def _prepare_agent_with_handoffs( - self, - executor: AgentExecutor, - target_agents: Mapping[str, Executor], - ) -> tuple[AgentExecutor, dict[str, str]]: - """Prepare an agent by adding handoff tools for the specified target agents. - - Args: - executor: The agent executor to prepare - target_agents: Map of executor IDs to target executors this agent can hand off to - - Returns: - Tuple of (updated executor, tool_targets map) - """ - agent = getattr(executor, "_agent", None) - if not isinstance(agent, ChatAgent): - return executor, {} + if isinstance(source, str) and ( + isinstance(targets, str) or (isinstance(targets, Sequence) and all(isinstance(t, str) for t in targets)) + ): + # Both source and targets are factory names + if not self._participant_factories: + raise ValueError("Call participant_factories(...) before add_handoff(...)") + + if source not in self._participant_factories: + raise ValueError(f"Source factory name '{source}' is not in the participant_factories list") + + target_list: list[str] = [targets] if isinstance(targets, str) else list(targets) # type: ignore + for target in target_list: + if target not in self._participant_factories: + raise ValueError(f"Target factory name '{target}' is not in the participant_factories list") + + self._handoff_config[source] = target_list # type: ignore + return self + + if isinstance(source, (AgentProtocol, Executor)) and ( + isinstance(targets, (AgentProtocol, Executor)) + or all(isinstance(t, (AgentProtocol, Executor)) for t in targets) + ): + # Both source and targets are instances + if not self._executors: + raise ValueError("Call participants(...) before add_handoff(...)") + + # Resolve source agent ID + source_id = self._resolve_to_id(source) + if source_id not in self._executors: + raise ValueError(f"Source agent '{source}' is not in the participants list") + + # Normalize targets to list + target_list: list[AgentProtocol | Executor] = ( # type: ignore[no-redef] + [targets] if isinstance(targets, (AgentProtocol, Executor)) else list(targets) + ) # type: ignore + + # Resolve all target IDs + target_ids: list[str] = [] + for target in target_list: + target_id = self._resolve_to_id(target) + if target_id not in self._executors: + raise ValueError(f"Target agent '{target}' is not in the participants list") + target_ids.append(target_id) + + # Merge with existing handoff configuration for this source + if source_id in self._handoff_config: + # Add new targets to existing list, avoiding duplicates + existing = self._handoff_config[source_id] + for target_id in target_ids: + if target_id not in existing: + existing.append(target_id) + else: + self._handoff_config[source_id] = target_ids - cloned_agent = _clone_chat_agent(agent) - tool_targets = self._apply_auto_tools(cloned_agent, target_agents) - if tool_targets: - middleware = _AutoHandoffMiddleware(tool_targets) - existing_middleware = list(cloned_agent.middleware or []) - existing_middleware.append(middleware) - cloned_agent.middleware = existing_middleware + return self - new_executor = AgentExecutor( - cloned_agent, - agent_thread=getattr(executor, "_agent_thread", None), - output_response=getattr(executor, "_output_response", False), - id=executor.id, + raise TypeError( + "Cannot mix factory names (str) and AgentProtocol/Executor instances " + "across source and targets in add_handoff()" ) - return new_executor, tool_targets def request_prompt(self, prompt: str | None) -> "HandoffBuilder": """Set a custom prompt message displayed when requesting user input. @@ -1548,75 +1690,46 @@ def build(self) -> Workflow: After calling build(), the builder instance should not be reused. Create a new builder if you need to construct another workflow with different configuration. """ - if not self._executors: - raise ValueError("No participants provided. Call participants([...]) first.") - if self._starting_agent_id is None: - raise ValueError("coordinator must be defined before build().") + if not self._executors and not self._participant_factories: + raise ValueError( + "No participants or participant_factories have been configured. " + "Call participants(...) or participant_factories(...) first." + ) - starting_executor = self._executors[self._starting_agent_id] - specialists = { - exec_id: executor for exec_id, executor in self._executors.items() if exec_id != self._starting_agent_id - } + if self._starting_agent_id is None: + raise ValueError("Must call set_coordinator(...) before building the workflow.") - # Build handoff tool registry for all agents that need them - handoff_tool_targets: dict[str, str] = {} - if self._auto_register_handoff_tools: - # Determine which agents should have handoff tools - if self._handoff_config: - # Use explicit handoff configuration from add_handoff() calls - for source_exec_id, target_exec_ids in self._handoff_config.items(): - executor = self._executors.get(source_exec_id) - if not executor: - raise ValueError(f"Handoff source agent '{source_exec_id}' not found in participants") - - if isinstance(executor, AgentExecutor): - # Build targets map for this source agent - targets_map: dict[str, Executor] = {} - for target_exec_id in target_exec_ids: - target_executor = self._executors.get(target_exec_id) - if not target_executor: - raise ValueError(f"Handoff target agent '{target_exec_id}' not found in participants") - targets_map[target_exec_id] = target_executor - - # Register handoff tools for this agent - updated_executor, tool_targets = self._prepare_agent_with_handoffs(executor, targets_map) - self._executors[source_exec_id] = updated_executor - handoff_tool_targets.update(tool_targets) - else: - # Default behavior: only coordinator gets handoff tools to all specialists - if isinstance(starting_executor, AgentExecutor) and specialists: - starting_executor, tool_targets = self._prepare_agent_with_handoffs(starting_executor, specialists) - self._executors[self._starting_agent_id] = starting_executor - handoff_tool_targets.update(tool_targets) # Update references after potential agent modifications - starting_executor = self._executors[self._starting_agent_id] - specialists = { - exec_id: executor for exec_id, executor in self._executors.items() if exec_id != self._starting_agent_id - } + # Resolve executors, aliases, and handoff tool targets + # This will instantiate participants if using factories, and validate handoff config + start_executor_id, executors, aliases, handoff_tool_targets = self._resolve_executors_and_handoffs() + specialists = {exec_id: executor for exec_id, executor in executors.items() if exec_id != start_executor_id} if not specialists: logger.warning("Handoff workflow has no specialist agents; the coordinator will loop with the user.") descriptions = { - exec_id: getattr(executor, "description", None) or exec_id for exec_id, executor in self._executors.items() + exec_id: getattr(executor, "description", None) or exec_id for exec_id, executor in executors.items() } participant_specs = { exec_id: GroupChatParticipantSpec(name=exec_id, participant=executor, description=descriptions[exec_id]) - for exec_id, executor in self._executors.items() + for exec_id, executor in executors.items() } input_node = _InputToConversation(id="input-conversation") user_gateway = _UserInputGateway( - starting_agent_id=starting_executor.id, + starting_agent_id=start_executor_id, prompt=self._request_prompt, id="handoff-user-input", ) builder = WorkflowBuilder(name=self._name, description=self._description).set_start_executor(input_node) - specialist_aliases = {alias: exec_id for alias, exec_id in self._aliases.items() if exec_id in specialists} + specialist_aliases = { + alias: specialists[exec_id].id for alias, exec_id in aliases.items() if exec_id in specialists + } def _handoff_orchestrator_factory(_: _GroupChatConfig) -> Executor: return _HandoffCoordinator( - starting_agent_id=starting_executor.id, + starting_agent_id=start_executor_id, specialist_ids=specialist_aliases, input_gateway_id=user_gateway.id, termination_condition=self._termination_condition, @@ -1633,8 +1746,8 @@ def _handoff_orchestrator_factory(_: _GroupChatConfig) -> Executor: manager_name=self._starting_agent_id, participants=participant_specs, max_rounds=None, - participant_aliases=self._aliases, - participant_executors=self._executors, + participant_aliases=aliases, + participant_executors=executors, ) # Determine participant factory - wrap with request info interceptor if enabled @@ -1683,14 +1796,159 @@ def _factory_with_request_info( builder = builder.add_edge(input_node, starting_entry_executor) else: # Fallback to direct connection if interceptor not found - builder = builder.add_edge(input_node, starting_executor) + builder = builder.add_edge(input_node, executors[start_executor_id]) else: - builder = builder.add_edge(input_node, starting_executor) + builder = builder.add_edge(input_node, executors[start_executor_id]) builder = builder.add_edge(coordinator, user_gateway) builder = builder.add_edge(user_gateway, coordinator) return builder.build() + # endregion Fluent Configuration Methods + + # region Internal Helper Methods + + def _resolve_executors(self) -> tuple[dict[str, Executor], dict[str, str]]: + """Resolve participant factories into executor instances. + + If executors were provided directly via participants(...), those are returned as-is. + If participant factories were provided via participant_factories(...), those + are invoked to create executor instances and aliases. + + Returns: + Tuple of (executors map, aliases map) + """ + if self._executors and self._participant_factories: + raise ValueError("Cannot have both executors and participant_factories configured") + + if self._executors: + if self._aliases: + # Return existing executors and aliases + return self._executors, self._aliases + raise ValueError("Aliases is empty despite executors being provided") + + if self._participant_factories: + # Invoke each factory to create participant instances + executor_ids_to_executors: dict[str, AgentProtocol | Executor] = {} + factory_names_to_ids: dict[str, str] = {} + for factory_name, factory in self._participant_factories.items(): + instance: Executor | AgentProtocol = factory() + if isinstance(instance, Executor): + identifier = instance.id + elif isinstance(instance, AgentProtocol): + identifier = instance.display_name + else: + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(instance).__name__}." + ) + + if identifier in executor_ids_to_executors: + raise ValueError(f"Duplicate participant name '{identifier}' detected") + executor_ids_to_executors[identifier] = instance + factory_names_to_ids[factory_name] = identifier + + # Prepare metadata and wrap instances as needed + metadata = prepare_participant_metadata( + executor_ids_to_executors, + description_factory=lambda name, participant: getattr(participant, "description", None) or name, + ) + + wrapped = metadata["executors"] + # Map executors by factory name (not executor.id) because handoff configs reference factory names + # This allows users to configure handoffs using the factory names they provided + executors = { + factory_name: wrapped[executor_id] for factory_name, executor_id in factory_names_to_ids.items() + } + aliases = metadata["aliases"] + + return executors, aliases + + raise ValueError("No executors or participant_factories have been configured") + + def _resolve_handoffs(self, executors: Mapping[str, Executor]) -> tuple[dict[str, Executor], dict[str, str]]: + """Handoffs may be specified using factory names or instances; resolve to executor IDs. + + Args: + executors: Map of executor IDs or factory names to Executor instances + + Returns: + Tuple of (updated executors map, handoff configuration map) + The updated executors map may have modified agents with handoff tools added + and maps executor IDs to Executor instances. + The handoff configuration map maps executor IDs to lists of target executor IDs. + """ + handoff_tool_targets: dict[str, str] = {} + updated_executors = {executor.id: executor for executor in executors.values()} + # Determine which agents should have handoff tools + if self._handoff_config: + # Use explicit handoff configuration from add_handoff() calls + for source_id, target_ids in self._handoff_config.items(): + executor = executors.get(source_id) + if not executor: + raise ValueError( + f"Handoff source agent '{source_id}' not found. " + "Please make sure source has been added as either a participant or participant_factory." + ) + + if isinstance(executor, AgentExecutor): + # Build targets map for this source agent + targets_map: dict[str, Executor] = {} + for target_id in target_ids: + target_executor = executors.get(target_id) + if not target_executor: + raise ValueError( + f"Handoff target agent '{target_id}' not found. " + "Please make sure target has been added as either a participant or participant_factory." + ) + targets_map[target_executor.id] = target_executor + + # Register handoff tools for this agent + updated_executor, tool_targets = self._prepare_agent_with_handoffs(executor, targets_map) + updated_executors[updated_executor.id] = updated_executor + handoff_tool_targets.update(tool_targets) + else: + if self._starting_agent_id is None or self._starting_agent_id not in executors: + raise RuntimeError("Failed to resolve default handoff configuration due to missing starting agent.") + + # Default behavior: only coordinator gets handoff tools to all specialists + starting_executor = executors[self._starting_agent_id] + specialists = { + executor.id: executor for executor in executors.values() if executor.id != starting_executor.id + } + + if isinstance(starting_executor, AgentExecutor) and specialists: + starting_executor, tool_targets = self._prepare_agent_with_handoffs(starting_executor, specialists) + updated_executors[starting_executor.id] = starting_executor + handoff_tool_targets.update(tool_targets) # Update references after potential agent modifications + + return updated_executors, handoff_tool_targets + + def _resolve_executors_and_handoffs(self) -> tuple[str, dict[str, Executor], dict[str, str], dict[str, str]]: + """Resolve participant factories into executor instances and handoff configurations. + + If executors were provided directly via participants(...), those are returned as-is. + If participant factories were provided via participant_factories(...), those + are invoked to create executor instances and aliases. + + Returns: + Tuple of (executors map, aliases map, handoff configuration map) + """ + # Resolve the participant factories now. This doesn't break the factory pattern + # since the Handoff builder still creates new instances per workflow build. + executors, aliases = self._resolve_executors() + # `self._starting_agent_id` is either a factory name or executor ID at this point, + # resolve to executor ID + if self._starting_agent_id in executors: + start_executor_id = executors[self._starting_agent_id].id + else: + raise RuntimeError("Failed to resolve starting agent ID during build.") + + # Resolve handoffs + # This will update the `executors` dict to a map of executor IDs to executors + updated_executors, handoff_tool_targets = self._resolve_handoffs(executors) + + return start_executor_id, updated_executors, aliases, handoff_tool_targets + def _resolve_to_id(self, candidate: str | AgentProtocol | Executor) -> str: """Resolve a participant reference into a concrete executor identifier.""" if isinstance(candidate, Executor): @@ -1705,3 +1963,77 @@ def _resolve_to_id(self, candidate: str | AgentProtocol | Executor) -> str: return self._aliases[candidate] return candidate raise TypeError(f"Invalid starting agent reference: {type(candidate).__name__}") + + def _apply_auto_tools(self, agent: ChatAgent, specialists: Mapping[str, Executor]) -> dict[str, str]: + """Attach synthetic handoff tools to a chat agent and return the target lookup table. + + Creates handoff tools for each specialist agent that this agent can route to. + The tool_targets dict maps various name formats (tool name, sanitized name, alias) + to executor IDs to enable flexible handoff target resolution. + + Args: + agent: The ChatAgent to add handoff tools to + specialists: Map of executor IDs or factory names to specialist executors this agent can hand off to + + Returns: + Dict mapping tool names (in various formats) to executor IDs for handoff resolution + """ + chat_options = agent.chat_options + existing_tools = list(chat_options.tools or []) + existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")} + + tool_targets: dict[str, str] = {} + new_tools: list[Any] = [] + for executor in specialists.values(): + alias = executor.id + sanitized = sanitize_identifier(alias) + tool = _create_handoff_tool(alias, executor.description if isinstance(executor, AgentExecutor) else None) + if tool.name not in existing_names: + new_tools.append(tool) + # Map multiple name variations to the same executor ID for robust resolution + tool_targets[tool.name.lower()] = executor.id + tool_targets[sanitized] = executor.id + tool_targets[alias.lower()] = executor.id + + if new_tools: + chat_options.tools = existing_tools + new_tools + else: + chat_options.tools = existing_tools + + return tool_targets + + def _prepare_agent_with_handoffs( + self, + executor: AgentExecutor, + target_agents: Mapping[str, Executor], + ) -> tuple[AgentExecutor, dict[str, str]]: + """Prepare an agent by adding handoff tools for the specified target agents. + + Args: + executor: The agent executor to prepare + target_agents: Map of executor IDs to target executors this agent can hand off to + + Returns: + Tuple of (updated executor, tool_targets map) + """ + agent = getattr(executor, "_agent", None) + if not isinstance(agent, ChatAgent): + return executor, {} + + cloned_agent = _clone_chat_agent(agent) + tool_targets = self._apply_auto_tools(cloned_agent, target_agents) + if tool_targets: + middleware = _AutoHandoffMiddleware(tool_targets) + existing_middleware = list(cloned_agent.middleware or []) + existing_middleware.append(middleware) + cloned_agent.middleware = existing_middleware + + new_executor = AgentExecutor( + cloned_agent, + agent_thread=getattr(executor, "_agent_thread", None), + output_response=getattr(executor, "_output_response", False), + id=executor.id, + ) + return new_executor, tool_targets + + # endregion Internal Helper Methods diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index a24fd77b16..cdbc79e0c0 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -25,7 +25,7 @@ from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator from ._checkpoint import CheckpointStorage, WorkflowCheckpoint -from ._const import EXECUTOR_STATE_KEY +from ._const import EXECUTOR_STATE_KEY, WORKFLOW_RUN_KWARGS_KEY from ._events import AgentRunUpdateEvent, WorkflowEvent from ._executor import Executor, handler from ._group_chat import ( @@ -286,12 +286,14 @@ class _MagenticStartMessage(DictConvertible): """Internal: A message to start a magentic workflow.""" messages: list[ChatMessage] = field(default_factory=_new_chat_message_list) + run_kwargs: dict[str, Any] = field(default_factory=dict) def __init__( self, messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None, *, task: ChatMessage | None = None, + run_kwargs: dict[str, Any] | None = None, ) -> None: normalized = normalize_messages_input(messages) if task is not None: @@ -299,6 +301,7 @@ def __init__( if not normalized: raise ValueError("MagenticStartMessage requires at least one message input.") self.messages: list[ChatMessage] = normalized + self.run_kwargs: dict[str, Any] = run_kwargs or {} @property def task(self) -> ChatMessage: @@ -1179,6 +1182,10 @@ async def handle_start_message( return logger.info("Magentic Orchestrator: Received start message") + # Store run_kwargs in SharedState so agent executors can access them + # Always store (even empty dict) so retrieval is deterministic + await context.set_shared_state(WORKFLOW_RUN_KWARGS_KEY, message.run_kwargs or {}) + self._context = MagenticContext( task=message.task, participant_descriptions=self._participants, @@ -2004,10 +2011,12 @@ async def _invoke_agent( """ logger.debug(f"Agent {self._agent_id}: Running with {len(self._chat_history)} messages") + run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + updates: list[AgentRunResponseUpdate] = [] # The wrapped participant is guaranteed to be an BaseAgent when this is called. agent = cast("AgentProtocol", self._agent) - async for update in agent.run_stream(messages=self._chat_history): # type: ignore[attr-defined] + async for update in agent.run_stream(messages=self._chat_history, **run_kwargs): # type: ignore[attr-defined] updates.append(update) await self._emit_agent_delta_event(ctx, update) @@ -2604,38 +2613,48 @@ def workflow(self) -> Workflow: """Access the underlying workflow.""" return self._workflow - async def run_streaming_with_string(self, task_text: str) -> AsyncIterable[WorkflowEvent]: + async def run_streaming_with_string(self, task_text: str, **kwargs: Any) -> AsyncIterable[WorkflowEvent]: """Run the workflow with a task string. Args: task_text: The task description as a string. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These kwargs will be available in @ai_function tools via **kwargs. Yields: WorkflowEvent: The events generated during the workflow execution. """ start_message = _MagenticStartMessage.from_string(task_text) + start_message.run_kwargs = kwargs async for event in self._workflow.run_stream(start_message): yield event - async def run_streaming_with_message(self, task_message: ChatMessage) -> AsyncIterable[WorkflowEvent]: + async def run_streaming_with_message( + self, task_message: ChatMessage, **kwargs: Any + ) -> AsyncIterable[WorkflowEvent]: """Run the workflow with a ChatMessage. Args: task_message: The task as a ChatMessage. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These kwargs will be available in @ai_function tools via **kwargs. Yields: WorkflowEvent: The events generated during the workflow execution. """ - start_message = _MagenticStartMessage(task_message) + start_message = _MagenticStartMessage(task_message, run_kwargs=kwargs) async for event in self._workflow.run_stream(start_message): yield event - async def run_stream(self, message: Any | None = None) -> AsyncIterable[WorkflowEvent]: + async def run_stream(self, message: Any | None = None, **kwargs: Any) -> AsyncIterable[WorkflowEvent]: """Run the workflow with either a message object or the preset task string. Args: message: The message to send. If None and task_text was provided during construction, uses the preset task string. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These kwargs will be available in @ai_function tools via **kwargs. + Example: workflow.run_stream("task", user_id="123", custom_data={...}) Yields: WorkflowEvent: The events generated during the workflow execution. @@ -2643,13 +2662,19 @@ async def run_stream(self, message: Any | None = None) -> AsyncIterable[Workflow if message is None: if self._task_text is None: raise ValueError("No message provided and no preset task text available") - message = _MagenticStartMessage.from_string(self._task_text) + start_message = _MagenticStartMessage.from_string(self._task_text) elif isinstance(message, str): - message = _MagenticStartMessage.from_string(message) + start_message = _MagenticStartMessage.from_string(message) elif isinstance(message, (ChatMessage, list)): - message = _MagenticStartMessage(message) # type: ignore[arg-type] + start_message = _MagenticStartMessage(message) # type: ignore[arg-type] + else: + start_message = message - async for event in self._workflow.run_stream(message): + # Attach kwargs to the start message + if isinstance(start_message, _MagenticStartMessage): + start_message.run_kwargs = kwargs + + async for event in self._workflow.run_stream(start_message): yield event async def _validate_checkpoint_participants( @@ -2730,46 +2755,49 @@ async def _validate_checkpoint_participants( f"Missing names: {missing}; unexpected names: {unexpected}." ) - async def run_with_string(self, task_text: str) -> WorkflowRunResult: + async def run_with_string(self, task_text: str, **kwargs: Any) -> WorkflowRunResult: """Run the workflow with a task string and return all events. Args: task_text: The task description as a string. + **kwargs: Additional keyword arguments to pass through to agent invocations. Returns: WorkflowRunResult: All events generated during the workflow execution. """ events: list[WorkflowEvent] = [] - async for event in self.run_streaming_with_string(task_text): + async for event in self.run_streaming_with_string(task_text, **kwargs): events.append(event) return WorkflowRunResult(events) - async def run_with_message(self, task_message: ChatMessage) -> WorkflowRunResult: + async def run_with_message(self, task_message: ChatMessage, **kwargs: Any) -> WorkflowRunResult: """Run the workflow with a ChatMessage and return all events. Args: task_message: The task as a ChatMessage. + **kwargs: Additional keyword arguments to pass through to agent invocations. Returns: WorkflowRunResult: All events generated during the workflow execution. """ events: list[WorkflowEvent] = [] - async for event in self.run_streaming_with_message(task_message): + async for event in self.run_streaming_with_message(task_message, **kwargs): events.append(event) return WorkflowRunResult(events) - async def run(self, message: Any | None = None) -> WorkflowRunResult: + async def run(self, message: Any | None = None, **kwargs: Any) -> WorkflowRunResult: """Run the workflow and return all events. Args: message: The message to send. If None and task_text was provided during construction, uses the preset task string. + **kwargs: Additional keyword arguments to pass through to agent invocations. Returns: WorkflowRunResult: All events generated during the workflow execution. """ events: list[WorkflowEvent] = [] - async for event in self.run_stream(message): + async for event in self.run_stream(message, **kwargs): events.append(event) return WorkflowRunResult(events) diff --git a/python/packages/core/agent_framework/_workflows/_participant_utils.py b/python/packages/core/agent_framework/_workflows/_participant_utils.py index ac632a917d..a6f1cf2a84 100644 --- a/python/packages/core/agent_framework/_workflows/_participant_utils.py +++ b/python/packages/core/agent_framework/_workflows/_participant_utils.py @@ -47,15 +47,13 @@ def wrap_participant(participant: AgentProtocol | Executor, *, executor_id: str """Represent `participant` as an `Executor`.""" if isinstance(participant, Executor): return participant + if not isinstance(participant, AgentProtocol): raise TypeError( f"Participants must implement AgentProtocol or be Executor instances. Got {type(participant).__name__}." ) - name = getattr(participant, "name", None) - if executor_id is None: - if not name: - raise ValueError("Agent participants must expose a stable 'name' attribute.") - executor_id = str(name) + + executor_id = executor_id or participant.display_name return AgentExecutor(participant, id=executor_id) diff --git a/python/packages/core/agent_framework/_workflows/_sequential.py b/python/packages/core/agent_framework/_workflows/_sequential.py index b6a49ecab8..24ae4cda29 100644 --- a/python/packages/core/agent_framework/_workflows/_sequential.py +++ b/python/packages/core/agent_framework/_workflows/_sequential.py @@ -4,7 +4,8 @@ This module provides a high-level, agent-focused API to assemble a sequential workflow where: -- Participants are a sequence of AgentProtocol instances or Executors +- Participants can be provided as AgentProtocol or Executor instances via `.participants()`, + or as factories returning AgentProtocol or Executor via `.register_participants()` - A shared conversation context (list[ChatMessage]) is passed along the chain - Agents append their assistant messages to the context - Custom executors can transform or summarize and return a refined context @@ -15,7 +16,7 @@ Notes: - Participants can mix AgentProtocol and Executor objects -- Agents are auto-wrapped by WorkflowBuilder as AgentExecutor +- Agents are auto-wrapped by WorkflowBuilder as AgentExecutor (unless already wrapped) - AgentExecutor produces AgentExecutorResponse; _ResponseToConversation converts this to list[ChatMessage] - Non-agent executors must define a handler that consumes `list[ChatMessage]` and sends back the updated `list[ChatMessage]` via their workflow context @@ -153,6 +154,9 @@ def register_participants( "Cannot mix .participants([...]) and .register_participants() in the same builder instance." ) + if self._participant_factories: + raise ValueError("register_participants() has already been called on this builder instance.") + if not participant_factories: raise ValueError("participant_factories cannot be empty") @@ -170,6 +174,9 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Seq "Cannot mix .participants([...]) and .register_participants() in the same builder instance." ) + if self._participants: + raise ValueError("participants() has already been called on this builder instance.") + if not participants: raise ValueError("participants cannot be empty") @@ -252,7 +259,7 @@ def build(self) -> Workflow: if not self._participants and not self._participant_factories: raise ValueError( "No participants or participant factories provided to the builder. " - "Use .participants([...]) or .ss([...])." + "Use .participants([...]) or .register_participants([...])." ) if self._participants and self._participant_factories: @@ -273,6 +280,8 @@ def build(self) -> Workflow: participants: list[Executor | AgentProtocol] = [] if self._participant_factories: + # Resolve the participant factories now. This doesn't break the factory pattern + # since the Sequential builder still creates new instances per workflow build. for factory in self._participant_factories: p = factory() participants.append(p) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index caa60fbef6..7b446926fc 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -13,7 +13,7 @@ from ..observability import OtelAttr, capture_exception, create_workflow_span from ._agent import WorkflowAgent from ._checkpoint import CheckpointStorage -from ._const import DEFAULT_MAX_ITERATIONS +from ._const import DEFAULT_MAX_ITERATIONS, WORKFLOW_RUN_KWARGS_KEY from ._edge import ( EdgeGroup, FanOutEdgeGroup, @@ -291,6 +291,7 @@ async def _run_workflow_with_tracing( initial_executor_fn: Callable[[], Awaitable[None]] | None = None, reset_context: bool = True, streaming: bool = False, + run_kwargs: dict[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Private method to run workflow with proper tracing. @@ -301,6 +302,7 @@ async def _run_workflow_with_tracing( initial_executor_fn: Optional function to execute initial executor reset_context: Whether to reset the context for a new run streaming: Whether to enable streaming mode for agents + run_kwargs: Optional kwargs to store in SharedState for agent invocations Yields: WorkflowEvent: The events generated during the workflow execution. @@ -335,6 +337,10 @@ async def _run_workflow_with_tracing( self._runner.context.reset_for_new_run() await self._shared_state.clear() + # Store run kwargs in SharedState so executors can access them + # Always store (even empty dict) so retrieval is deterministic + await self._shared_state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs or {}) + # Set streaming mode after reset self._runner_context.set_streaming(streaming) @@ -442,6 +448,7 @@ async def run_stream( *, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, + **kwargs: Any, ) -> AsyncIterable[WorkflowEvent]: """Run the workflow and stream events. @@ -457,6 +464,9 @@ async def run_stream( - With checkpoint_id: Used to load and restore the specified checkpoint - Without checkpoint_id: Enables checkpointing for this run, overriding build-time configuration + **kwargs: Additional keyword arguments to pass through to agent invocations. + These are stored in SharedState and accessible in @ai_function tools + via the **kwargs parameter. Yields: WorkflowEvent: Events generated during workflow execution. @@ -475,6 +485,17 @@ async def run_stream( async for event in workflow.run_stream("start message"): process(event) + With custom context for ai_functions: + + .. code-block:: python + + async for event in workflow.run_stream( + "analyze data", + custom_data={"endpoint": "https://api.example.com"}, + user_token={"user": "alice"}, + ): + process(event) + Enable checkpointing at runtime: .. code-block:: python @@ -524,6 +545,7 @@ async def run_stream( ), reset_context=reset_context, streaming=True, + run_kwargs=kwargs if kwargs else None, ): yield event finally: @@ -559,6 +581,7 @@ async def run( checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, + **kwargs: Any, ) -> WorkflowRunResult: """Run the workflow to completion and return all events. @@ -575,6 +598,9 @@ async def run( - Without checkpoint_id: Enables checkpointing for this run, overriding build-time configuration include_status_events: Whether to include WorkflowStatusEvent instances in the result list. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These are stored in SharedState and accessible in @ai_function tools + via the **kwargs parameter. Returns: A WorkflowRunResult instance containing events generated during workflow execution. @@ -593,6 +619,16 @@ async def run( result = await workflow.run("start message") outputs = result.get_outputs() + With custom context for ai_functions: + + .. code-block:: python + + result = await workflow.run( + "analyze data", + custom_data={"endpoint": "https://api.example.com"}, + user_token={"user": "alice"}, + ) + Enable checkpointing at runtime: .. code-block:: python @@ -637,6 +673,7 @@ async def run( self._execute_with_message_or_checkpoint, message, checkpoint_id, checkpoint_storage ), reset_context=reset_context, + run_kwargs=kwargs if kwargs else None, ) ] finally: diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 26cd0213e4..60c959823f 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -302,7 +302,8 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: If multiple names are provided, the same factory function will be registered under each name. - ...code-block:: python + .. code-block:: python + from agent_framework import WorkflowBuilder, Executor, WorkflowContext, handler @@ -315,7 +316,7 @@ async def log(self, message: str, ctx: WorkflowContext) -> None: # Register the same executor factory under multiple names workflow = ( WorkflowBuilder() - .register_executor(lambda: CustomExecutor(id="logger"), name=["ExecutorA", "ExecutorB"]) + .register_executor(lambda: LoggerExecutor(id="logger"), name=["ExecutorA", "ExecutorB"]) .set_start_executor("ExecutorA") .add_edge("ExecutorA", "ExecutorB") .build() @@ -374,7 +375,7 @@ def register_agent( ) """ if name in self._executor_registry: - raise ValueError(f"An executor factory with the name '{name}' is already registered.") + raise ValueError(f"An agent factory with the name '{name}' is already registered.") def wrapped_factory() -> AgentExecutor: agent = factory_func() @@ -456,7 +457,12 @@ def add_edge( source: The source executor or registered name of the source factory for the edge. target: The target executor or registered name of the target factory for the edge. condition: An optional condition function that determines whether the edge - should be traversed based on the message type. + should be traversed based on the message. + + Note: If instances are provided for both source and target, they will be shared across + all workflow instances created from the built Workflow. To avoid this, consider + registering the executors and agents using `register_executor` and `register_agent` + and referencing them by factory name for lazy initialization instead. Returns: Self: The WorkflowBuilder instance for method chaining. @@ -505,17 +511,13 @@ def only_large_numbers(msg: int) -> bool: .build() ) """ - if not isinstance(source, str) or not isinstance(target, str): - logger.warning( - "Adding an edge with Executor or AgentProtocol instances directly is not recommended, " - "because workflow instances created from the builder will share the same executor/agent instances. " - "Consider using a registered name for lazy initialization instead." - ) - if (isinstance(source, str) and not isinstance(target, str)) or ( not isinstance(source, str) and isinstance(target, str) ): - raise ValueError("Both source and target must be either names (str) or Executor/AgentProtocol instances.") + raise ValueError( + "Both source and target must be either registered factory names (str) " + "or Executor/AgentProtocol instances." + ) if isinstance(source, str) and isinstance(target, str): # Both are names; defer resolution to build time @@ -547,6 +549,11 @@ def add_fan_out_edges( Returns: Self: The WorkflowBuilder instance for method chaining. + Note: If instances are provided for source and targets, they will be shared across + all workflow instances created from the built Workflow. To avoid this, consider + registering the executors and agents using `register_executor` and `register_agent` + and referencing them by factory name for lazy initialization instead. + Example: .. code-block:: python @@ -583,17 +590,13 @@ async def validate(self, data: str, ctx: WorkflowContext) -> None: .build() ) """ - if not isinstance(source, str) or any(not isinstance(t, str) for t in targets): - logger.warning( - "Adding fan-out edges with Executor or AgentProtocol instances directly is not recommended, " - "because workflow instances created from the builder will share the same executor/agent instances. " - "Consider using registered names for lazy initialization instead." - ) - if (isinstance(source, str) and not all(isinstance(t, str) for t in targets)) or ( not isinstance(source, str) and any(isinstance(t, str) for t in targets) ): - raise ValueError("Both source and targets must be either names (str) or Executor/AgentProtocol instances.") + raise ValueError( + "Both source and targets must be either registered factory names (str) " + "or Executor/AgentProtocol instances." + ) if isinstance(source, str) and all(isinstance(t, str) for t in targets): # Both are names; defer resolution to build time @@ -624,7 +627,7 @@ def add_switch_case_edge_group( Each condition function will be evaluated in order, and the first one that returns True will determine which target executor receives the message. - The last case (the default case) will receive messages that fall through all conditions + The default case (if provided) will receive messages that fall through all conditions (i.e., no condition matched). Args: @@ -634,6 +637,11 @@ def add_switch_case_edge_group( Returns: Self: The WorkflowBuilder instance for method chaining. + Note: If instances are provided for source and case targets, they will be shared across + all workflow instances created from the built Workflow. To avoid this, consider + registering the executors and agents using `register_executor` and `register_agent` + and referencing them by factory name for lazy initialization instead. + Example: .. code-block:: python @@ -681,18 +689,12 @@ async def handle(self, result: Result, ctx: WorkflowContext) -> None: .build() ) """ - if not isinstance(source, str) or not all(isinstance(case.target, str) for case in cases): - logger.warning( - "Adding a switch-case edge group with Executor or AgentProtocol instances directly is not recommended, " - "because workflow instances created from the builder will share the same executor/agent instance. " - "Consider using a registered name for lazy initialization instead." - ) - if (isinstance(source, str) and not all(isinstance(case.target, str) for case in cases)) or ( not isinstance(source, str) and any(isinstance(case.target, str) for case in cases) ): raise ValueError( - "Both source and case targets must be either names (str) or Executor/AgentProtocol instances." + "Both source and case targets must be either registered factory names (str) " + "or Executor/AgentProtocol instances." ) if isinstance(source, str) and all(isinstance(case.target, str) for case in cases): @@ -736,12 +738,16 @@ def add_multi_selection_edge_group( source: The source executor or registered name of the source factory for the edge group. targets: A list of target executors or registered names of the target factories for the edges. selection_func: A function that selects target executors for messages. - Takes (message, list[executor_id or registered target names]) and - returns list[executor_id or registered target names]. + Takes (message, list[executor_id]) and returns list[executor_id]. Returns: Self: The WorkflowBuilder instance for method chaining. + Note: If instances are provided for source and targets, they will be shared across + all workflow instances created from the built Workflow. To avoid this, consider + registering the executors and agents using `register_executor` and `register_agent` + and referencing them by factory name for lazy initialization instead. + Example: .. code-block:: python @@ -795,17 +801,13 @@ def select_workers(task: Task, available: list[str]) -> list[str]: .build() ) """ - if not isinstance(source, str) or any(not isinstance(t, str) for t in targets): - logger.warning( - "Adding fan-out edges with Executor or AgentProtocol instances directly is not recommended, " - "because workflow instances created from the builder will share the same executor/agent instances. " - "Consider using registered names for lazy initialization instead." - ) - if (isinstance(source, str) and not all(isinstance(t, str) for t in targets)) or ( not isinstance(source, str) and any(isinstance(t, str) for t in targets) ): - raise ValueError("Both source and targets must be either names (str) or Executor/AgentProtocol instances.") + raise ValueError( + "Both source and targets must be either registered factory names (str) " + "or Executor/AgentProtocol instances." + ) if isinstance(source, str) and all(isinstance(t, str) for t in targets): # Both are names; defer resolution to build time @@ -848,6 +850,11 @@ def add_fan_in_edges( Returns: Self: The WorkflowBuilder instance for method chaining. + Note: If instances are provided for sources and target, they will be shared across + all workflow instances created from the built Workflow. To avoid this, consider + registering the executors and agents using `register_executor` and `register_agent` + and referencing them by factory name for lazy initialization instead. + Example: .. code-block:: python @@ -879,17 +886,13 @@ async def aggregate(self, results: list[str], ctx: WorkflowContext[Never, str]) .build() ) """ - if not all(isinstance(s, str) for s in sources) or not isinstance(target, str): - logger.warning( - "Adding fan-in edges with Executor or AgentProtocol instances directly is not recommended, " - "because workflow instances created from the builder will share the same executor/agent instances. " - "Consider using registered names for lazy initialization instead." - ) - if (all(isinstance(s, str) for s in sources) and not isinstance(target, str)) or ( not all(isinstance(s, str) for s in sources) and isinstance(target, str) ): - raise ValueError("Both sources and target must be either names (str) or Executor/AgentProtocol instances.") + raise ValueError( + "Both sources and target must be either registered factory names (str) " + "or Executor/AgentProtocol instances." + ) if all(isinstance(s, str) for s in sources) and isinstance(target, str): # Both are names; defer resolution to build time @@ -911,7 +914,7 @@ def add_chain(self, executors: Sequence[Executor | AgentProtocol | str]) -> Self The output of each executor in the chain will be sent to the next executor in the chain. The input types of each executor must be compatible with the output types of the previous executor. - Circles in the chain are not allowed, meaning the chain cannot have two executors with the same ID. + Cycles in the chain are not allowed, meaning an executor cannot appear more than once in the chain. Args: executors: A list of executors or registered names of the executor factories to chain together. @@ -919,6 +922,11 @@ def add_chain(self, executors: Sequence[Executor | AgentProtocol | str]) -> Self Returns: Self: The WorkflowBuilder instance for method chaining. + Note: If executor instances are provided, they will be shared across all workflow instances created + from the built Workflow. To avoid this, consider registering the executors and agents using + `register_executor` and `register_agent` and referencing them by factory name for lazy + initialization instead. + Example: .. code-block:: python @@ -958,16 +966,10 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: if len(executors) < 2: raise ValueError("At least two executors are required to form a chain.") - if not all(isinstance(e, str) for e in executors): - logger.warning( - "Adding a chain with Executor or AgentProtocol instances directly is not recommended, " - "because workflow instances created from the builder will share the same executor/agent instances. " - "Consider using registered names for lazy initialization instead." - ) - if not all(isinstance(e, str) for e in executors) and any(isinstance(e, str) for e in executors): raise ValueError( - "All executors in the chain must be either names (str) or Executor/AgentProtocol instances." + "All executors in the chain must be either registered factory names (str) " + "or Executor/AgentProtocol instances." ) if all(isinstance(e, str) for e in executors): @@ -976,7 +978,7 @@ async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None: self.add_edge(executors[i], executors[i + 1]) return self - # Both are Executor/AgentProtocol instances; wrap and add now + # All are Executor/AgentProtocol instances; wrap and add now # Wrap each candidate first to ensure stable IDs before adding edges wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors] # type: ignore[arg-type] for i in range(len(wrapped) - 1): @@ -1148,21 +1150,29 @@ def _resolve_edge_registry(self) -> tuple[Executor, list[Executor], list[EdgeGro if isinstance(self._start_executor, Executor): start_executor = self._start_executor - executors: dict[str, Executor] = {} + # Maps registered factory names to created executor instances for edge resolution + factory_name_to_instance: dict[str, Executor] = {} + # Maps executor IDs to created executor instances to prevent duplicates + executor_id_to_instance: dict[str, Executor] = {} deferred_edge_groups: list[EdgeGroup] = [] for name, exec_factory in self._executor_registry.items(): instance = exec_factory() + if instance.id in executor_id_to_instance: + raise ValueError(f"Executor with ID '{instance.id}' has already been created.") + executor_id_to_instance[instance.id] = instance + if isinstance(self._start_executor, str) and name == self._start_executor: start_executor = instance + # All executors will get their own internal edge group for receiving system messages deferred_edge_groups.append(InternalEdgeGroup(instance.id)) # type: ignore[call-arg] - executors[name] = instance + factory_name_to_instance[name] = instance def _get_executor(name: str) -> Executor: """Helper to get executor by the registered name. Raises if not found.""" - if name not in executors: - raise ValueError(f"Executor with name '{name}' has not been registered.") - return executors[name] + if name not in factory_name_to_instance: + raise ValueError(f"Factory '{name}' has not been registered.") + return factory_name_to_instance[name] for registration in self._edge_registry: match registration: @@ -1179,7 +1189,7 @@ def _get_executor(name: str) -> Executor: cases_converted: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = [] for case in cases: if not isinstance(case.target, str): - raise ValueError("Switch case target must be a registered executor name (str) if deferred.") + raise ValueError("Switch case target must be a registered factory name (str) if deferred.") target_exec = _get_executor(case.target) if isinstance(case, Default): cases_converted.append(SwitchCaseEdgeGroupDefault(target_id=target_exec.id)) @@ -1201,7 +1211,7 @@ def _get_executor(name: str) -> Executor: if start_executor is None: raise ValueError("Failed to resolve starting executor from registered factories.") - return start_executor, list(executors.values()), deferred_edge_groups + return start_executor, list(executor_id_to_instance.values()), deferred_edge_groups def build(self) -> Workflow: """Build and return the constructed workflow. diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index 544c0fdf5b..a8bfec0427 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -21,7 +21,7 @@ use_function_invocation, ) from agent_framework.exceptions import ServiceInitializationError -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from agent_framework.openai._chat_client import OpenAIBaseChatClient from ._shared import ( @@ -41,7 +41,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient): """Azure OpenAI Chat completion class.""" diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index 1d88d51688..3f6140eeeb 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -10,7 +10,7 @@ from agent_framework import use_chat_middleware, use_function_invocation from agent_framework.exceptions import ServiceInitializationError -from agent_framework.observability import use_observability +from agent_framework.observability import use_instrumentation from agent_framework.openai._responses_client import OpenAIBaseResponsesClient from ._shared import ( @@ -22,7 +22,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient): """Azure Responses completion class.""" diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index f3e1d9bd68..38fca796c1 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -3,15 +3,19 @@ import contextlib import json import logging +import os from collections.abc import AsyncIterable, Awaitable, Callable, Generator, Mapping from enum import Enum from functools import wraps from time import perf_counter, time_ns from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar +from dotenv import load_dotenv from opentelemetry import metrics, trace +from opentelemetry.sdk.resources import Resource +from opentelemetry.semconv.attributes import service_attributes from opentelemetry.semconv_ai import GenAISystem, Meters, SpanAttributes -from pydantic import BaseModel, PrivateAttr +from pydantic import PrivateAttr from . import __version__ as version_info from ._logging import get_logger @@ -19,10 +23,9 @@ from .exceptions import AgentInitializationError, ChatClientInitializationError if TYPE_CHECKING: # pragma: no cover - from azure.core.credentials import TokenCredential from opentelemetry.sdk._logs.export import LogRecordExporter from opentelemetry.sdk.metrics.export import MetricExporter - from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.metrics.view import View from opentelemetry.sdk.trace.export import SpanExporter from opentelemetry.trace import Tracer from opentelemetry.util._decorator import _AgnosticContextManager # type: ignore[reportPrivateUsage] @@ -44,11 +47,14 @@ __all__ = [ "OBSERVABILITY_SETTINGS", "OtelAttr", + "configure_otel_providers", + "create_metric_views", + "create_resource", + "enable_instrumentation", "get_meter", "get_tracer", - "setup_observability", - "use_agent_observability", - "use_observability", + "use_agent_instrumentation", + "use_instrumentation", ] @@ -259,89 +265,293 @@ def __str__(self) -> str: # region Telemetry utils -def _get_otlp_exporters(endpoints: list[str]) -> list["LogRecordExporter | SpanExporter | MetricExporter"]: - """Create standard OTLP Exporters for the supplied endpoints.""" - from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter - from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +# Parse headers helper +def _parse_headers(header_str: str) -> dict[str, str]: + """Parse header string like 'key1=value1,key2=value2' into dict.""" + headers: dict[str, str] = {} + if not header_str: + return headers + for pair in header_str.split(","): + if "=" in pair: + key, value = pair.split("=", 1) + headers[key.strip()] = value.strip() + return headers + + +def _create_otlp_exporters( + endpoint: str | None = None, + protocol: str = "grpc", + headers: dict[str, str] | None = None, + traces_endpoint: str | None = None, + traces_headers: dict[str, str] | None = None, + metrics_endpoint: str | None = None, + metrics_headers: dict[str, str] | None = None, + logs_endpoint: str | None = None, + logs_headers: dict[str, str] | None = None, +) -> list["LogRecordExporter | SpanExporter | MetricExporter"]: + """Create OTLP exporters for a given endpoint and protocol. + + Args: + endpoint: The OTLP endpoint URL (used for all exporters if individual endpoints not specified). + protocol: The protocol to use ("grpc" or "http"). Default is "grpc". + headers: Optional headers to include in requests (used for all exporters if individual headers not specified). + traces_endpoint: Optional specific endpoint for traces. Overrides endpoint parameter. + traces_headers: Optional specific headers for traces. Overrides headers parameter. + metrics_endpoint: Optional specific endpoint for metrics. Overrides endpoint parameter. + metrics_headers: Optional specific headers for metrics. Overrides headers parameter. + logs_endpoint: Optional specific endpoint for logs. Overrides endpoint parameter. + logs_headers: Optional specific headers for logs. Overrides headers parameter. + + Returns: + List containing OTLPLogExporter, OTLPSpanExporter, and OTLPMetricExporter. + + Raises: + ImportError: If the required OTLP exporter package is not installed. + """ + # Determine actual endpoints and headers to use + actual_traces_endpoint = traces_endpoint or endpoint + actual_metrics_endpoint = metrics_endpoint or endpoint + actual_logs_endpoint = logs_endpoint or endpoint + actual_traces_headers = traces_headers or headers + actual_metrics_headers = metrics_headers or headers + actual_logs_headers = logs_headers or headers exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = [] - for endpoint in endpoints: - exporters.append(OTLPLogExporter(endpoint=endpoint)) - exporters.append(OTLPSpanExporter(endpoint=endpoint)) - exporters.append(OTLPMetricExporter(endpoint=endpoint)) - return exporters + if not actual_logs_endpoint and not actual_traces_endpoint and not actual_metrics_endpoint: + return exporters + if protocol in ("grpc", "http/protobuf"): + # Import all gRPC exporters + try: + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter as GRPCLogExporter + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( + OTLPMetricExporter as GRPCMetricExporter, + ) + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter + except ImportError as exc: + raise ImportError( + "opentelemetry-exporter-otlp-proto-grpc is required for OTLP gRPC exporters. " + "Install it with: pip install opentelemetry-exporter-otlp-proto-grpc" + ) from exc + + if actual_logs_endpoint: + exporters.append( + GRPCLogExporter( + endpoint=actual_logs_endpoint, + headers=actual_logs_headers if actual_logs_headers else None, + ) + ) + if actual_traces_endpoint: + exporters.append( + GRPCSpanExporter( + endpoint=actual_traces_endpoint, + headers=actual_traces_headers if actual_traces_headers else None, + ) + ) + if actual_metrics_endpoint: + exporters.append( + GRPCMetricExporter( + endpoint=actual_metrics_endpoint, + headers=actual_metrics_headers if actual_metrics_headers else None, + ) + ) -def _get_azure_monitor_exporters( - connection_strings: list[str], - credential: "TokenCredential | None" = None, -) -> list["LogRecordExporter | SpanExporter | MetricExporter"]: - """Create Azure Monitor Exporters, based on the connection strings and optionally the credential.""" - try: - from azure.monitor.opentelemetry.exporter import ( - AzureMonitorLogExporter, - AzureMonitorMetricExporter, - AzureMonitorTraceExporter, - ) - except ImportError as e: - raise ImportError( - "azure-monitor-opentelemetry-exporter is required for Azure Monitor exporters. " - "Install it with: pip install azure-monitor-opentelemetry-exporter>=1.0.0b41" - ) from e + elif protocol == "http": + # Import all HTTP exporters + try: + from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter as HTTPLogExporter + from opentelemetry.exporter.otlp.proto.http.metric_exporter import ( + OTLPMetricExporter as HTTPMetricExporter, + ) + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter + except ImportError as exc: + raise ImportError( + "opentelemetry-exporter-otlp-proto-http is required for OTLP HTTP exporters. " + "Install it with: pip install opentelemetry-exporter-otlp-proto-http" + ) from exc + + if actual_logs_endpoint: + exporters.append( + HTTPLogExporter( + endpoint=actual_logs_endpoint, + headers=actual_logs_headers if actual_logs_headers else None, + ) + ) + if actual_traces_endpoint: + exporters.append( + HTTPSpanExporter( + endpoint=actual_traces_endpoint, + headers=actual_traces_headers if actual_traces_headers else None, + ) + ) + if actual_metrics_endpoint: + exporters.append( + HTTPMetricExporter( + endpoint=actual_metrics_endpoint, + headers=actual_metrics_headers if actual_metrics_headers else None, + ) + ) - exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = [] - for conn_string in connection_strings: - exporters.append(AzureMonitorLogExporter(connection_string=conn_string, credential=credential)) - exporters.append(AzureMonitorTraceExporter(connection_string=conn_string, credential=credential)) - exporters.append(AzureMonitorMetricExporter(connection_string=conn_string, credential=credential)) return exporters -def get_exporters( - otlp_endpoints: list[str] | None = None, - connection_strings: list[str] | None = None, - credential: "TokenCredential | None" = None, +def _get_exporters_from_env( + env_file_path: str | None = None, + env_file_encoding: str | None = None, ) -> list["LogRecordExporter | SpanExporter | MetricExporter"]: - """Add additional exporters to the existing configuration. + """Parse OpenTelemetry environment variables and create exporters. + + This function reads standard OpenTelemetry environment variables to configure + OTLP exporters for traces, logs, and metrics. + + The following environment variables are supported: + - OTEL_EXPORTER_OTLP_ENDPOINT: Base endpoint for all signals + - OTEL_EXPORTER_OTLP_TRACES_ENDPOINT: Endpoint specifically for traces + - OTEL_EXPORTER_OTLP_METRICS_ENDPOINT: Endpoint specifically for metrics + - OTEL_EXPORTER_OTLP_LOGS_ENDPOINT: Endpoint specifically for logs + - OTEL_EXPORTER_OTLP_PROTOCOL: Protocol to use (grpc, http/protobuf) + - OTEL_EXPORTER_OTLP_HEADERS: Headers for all signals + - OTEL_EXPORTER_OTLP_TRACES_HEADERS: Headers specifically for traces + - OTEL_EXPORTER_OTLP_METRICS_HEADERS: Headers specifically for metrics + - OTEL_EXPORTER_OTLP_LOGS_HEADERS: Headers specifically for logs - If you supply exporters, those will be added to the relevant providers directly. - If you supply endpoints or connection strings, new exporters will be created and added. - OTLP_endpoints will be used to create a `OTLPLogExporter`, `OTLPMetricExporter` and `OTLPSpanExporter` - Connection_strings will be used to create AzureMonitorExporters. + Args: + env_file_path: Path to a .env file to load environment variables from. + Default is None, which loads from '.env' if present. + env_file_encoding: Encoding to use when reading the .env file. + Default is None, which uses the system default encoding. - If a endpoint or connection string is already configured, through the environment variables, it will be skipped. - If you call this method twice with the same additional endpoint or connection string, it will be added twice. + Returns: + List of configured exporters (empty if no relevant env vars are set). - Args: - otlp_endpoints: A list of OpenTelemetry Protocol (OTLP) endpoints. Default is None. - connection_strings: A list of Azure Monitor connection strings. Default is None. - credential: The credential to use for Azure Monitor Entra ID authentication. Default is None. + References: + - https://opentelemetry.io/docs/languages/sdk-configuration/general/ + - https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter/ """ - new_exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = [] - if otlp_endpoints: - new_exporters.extend(_get_otlp_exporters(endpoints=otlp_endpoints)) - - if connection_strings: - new_exporters.extend( - _get_azure_monitor_exporters( - connection_strings=connection_strings, - credential=credential, + # Load environment variables from .env file if present + load_dotenv(dotenv_path=env_file_path, encoding=env_file_encoding) + + # Get base endpoint + base_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") + + # Get signal-specific endpoints (these override base endpoint) + traces_endpoint = os.getenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT") or base_endpoint + metrics_endpoint = os.getenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT") or base_endpoint + logs_endpoint = os.getenv("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT") or base_endpoint + + # Get protocol (default is grpc) + protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc").lower() + + # Get base headers + base_headers_str = os.getenv("OTEL_EXPORTER_OTLP_HEADERS", "") + base_headers = _parse_headers(base_headers_str) + + # Get signal-specific headers (these merge with base headers) + traces_headers_str = os.getenv("OTEL_EXPORTER_OTLP_TRACES_HEADERS", "") + metrics_headers_str = os.getenv("OTEL_EXPORTER_OTLP_METRICS_HEADERS", "") + logs_headers_str = os.getenv("OTEL_EXPORTER_OTLP_LOGS_HEADERS", "") + + traces_headers = {**base_headers, **_parse_headers(traces_headers_str)} + metrics_headers = {**base_headers, **_parse_headers(metrics_headers_str)} + logs_headers = {**base_headers, **_parse_headers(logs_headers_str)} + + # Create exporters using helper function + return _create_otlp_exporters( + protocol=protocol, + traces_endpoint=traces_endpoint, + traces_headers=traces_headers if traces_headers else None, + metrics_endpoint=metrics_endpoint, + metrics_headers=metrics_headers if metrics_headers else None, + logs_endpoint=logs_endpoint, + logs_headers=logs_headers if logs_headers else None, + ) + + +def create_resource( + service_name: str | None = None, + service_version: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + **attributes: Any, +) -> "Resource": + """Create an OpenTelemetry Resource from environment variables and parameters. + + This function reads standard OpenTelemetry environment variables to configure + the resource, which identifies your service in telemetry backends. + + The following environment variables are read: + - OTEL_SERVICE_NAME: The name of the service (defaults to "agent_framework") + - OTEL_SERVICE_VERSION: The version of the service (defaults to package version) + - OTEL_RESOURCE_ATTRIBUTES: Additional resource attributes as key=value pairs + + Args: + service_name: Override the service name. If not provided, reads from + OTEL_SERVICE_NAME environment variable or defaults to "agent_framework". + service_version: Override the service version. If not provided, reads from + OTEL_SERVICE_VERSION environment variable or defaults to the package version. + env_file_path: Path to a .env file to load environment variables from. + Default is None, which loads from '.env' if present. + env_file_encoding: Encoding to use when reading the .env file. + Default is None, which uses the system default encoding. + **attributes: Additional resource attributes to include. These will be merged + with attributes from OTEL_RESOURCE_ATTRIBUTES environment variable. + + Returns: + A configured OpenTelemetry Resource instance. + + Examples: + .. code-block:: python + + from agent_framework.observability import create_resource + + # Use defaults from environment variables + resource = create_resource() + + # Override service name + resource = create_resource(service_name="my_service") + + # Add custom attributes + resource = create_resource( + service_name="my_service", service_version="1.0.0", deployment_environment="production" ) - ) - return new_exporters + # Load from custom .env file + resource = create_resource(env_file_path="config/.env") + """ + # Load environment variables from .env file if present + load_dotenv(dotenv_path=env_file_path, encoding=env_file_encoding) + + # Start with provided attributes + resource_attributes: dict[str, Any] = dict(attributes) + + # Set service name + if service_name is None: + service_name = os.getenv("OTEL_SERVICE_NAME", "agent_framework") + resource_attributes[service_attributes.SERVICE_NAME] = service_name -def _create_resource() -> "Resource": - import os + # Set service version + if service_version is None: + service_version = os.getenv("OTEL_SERVICE_VERSION", version_info) + resource_attributes[service_attributes.SERVICE_VERSION] = service_version - from opentelemetry.sdk.resources import Resource - from opentelemetry.semconv.attributes import service_attributes + # Parse OTEL_RESOURCE_ATTRIBUTES environment variable + # Format: key1=value1,key2=value2 + if resource_attrs_env := os.getenv("OTEL_RESOURCE_ATTRIBUTES"): + resource_attributes.update(_parse_headers(resource_attrs_env)) + return Resource.create(resource_attributes) - service_name = os.getenv("OTEL_SERVICE_NAME", "agent_framework") - return Resource.create({service_attributes.SERVICE_NAME: service_name}) +def create_metric_views() -> list["View"]: + """Create the default OpenTelemetry metric views for Agent Framework.""" + from opentelemetry.sdk.metrics.view import DropAggregation, View + + return [ + # Dropping all enable_instrumentation names except for those starting with "agent_framework" + View(instrument_name="agent_framework*"), + View(instrument_name="gen_ai*"), + View(instrument_name="*", aggregation=DropAggregation()), + ] class ObservabilitySettings(AFBaseSettings): @@ -357,14 +567,12 @@ class ObservabilitySettings(AFBaseSettings): Sensitive events should only be enabled on test and development environments. Keyword Args: - enable_otel: Enable OpenTelemetry diagnostics. Default is False. - Can be set via environment variable ENABLE_OTEL. + enable_instrumentation: Enable OpenTelemetry diagnostics. Default is False. + Can be set via environment variable ENABLE_INSTRUMENTATION. enable_sensitive_data: Enable OpenTelemetry sensitive events. Default is False. Can be set via environment variable ENABLE_SENSITIVE_DATA. - applicationinsights_connection_string: The Azure Monitor connection string. Default is None. - Can be set via environment variable APPLICATIONINSIGHTS_CONNECTION_STRING. - otlp_endpoint: The OpenTelemetry Protocol (OTLP) endpoint. Default is None. - Can be set via environment variable OTLP_ENDPOINT. + enable_console_exporters: Enable console exporters for traces, logs, and metrics. + Default is False. Can be set via environment variable ENABLE_CONSOLE_EXPORTERS. vs_code_extension_port: The port the AI Toolkit or Azure AI Foundry VS Code extensions are listening on. Default is None. Can be set via environment variable VS_CODE_EXTENSION_PORT. @@ -375,33 +583,39 @@ class ObservabilitySettings(AFBaseSettings): from agent_framework import ObservabilitySettings # Using environment variables - # Set ENABLE_OTEL=true - # Set APPLICATIONINSIGHTS_CONNECTION_STRING=InstrumentationKey=... + # Set ENABLE_INSTRUMENTATION=true + # Set ENABLE_CONSOLE_EXPORTERS=true settings = ObservabilitySettings() # Or passing parameters directly - settings = ObservabilitySettings( - enable_otel=True, applicationinsights_connection_string="InstrumentationKey=..." - ) + settings = ObservabilitySettings(enable_instrumentation=True, enable_console_exporters=True) """ env_prefix: ClassVar[str] = "" - enable_otel: bool = False + enable_instrumentation: bool = False enable_sensitive_data: bool = False - applicationinsights_connection_string: str | list[str] | None = None - otlp_endpoint: str | list[str] | None = None + enable_console_exporters: bool = False vs_code_extension_port: int | None = None - _resource: "Resource" = PrivateAttr(default_factory=_create_resource) + _resource: "Resource" = PrivateAttr() _executed_setup: bool = PrivateAttr(default=False) + def __init__(self, **kwargs: Any) -> None: + """Initialize the settings and create the resource.""" + super().__init__(**kwargs) + # Create resource with env file settings + self._resource = create_resource( + env_file_path=self.env_file_path, + env_file_encoding=self.env_file_encoding, + ) + @property def ENABLED(self) -> bool: """Check if model diagnostics are enabled. Model diagnostics are enabled if either diagnostic is enabled or diagnostic with sensitive events is enabled. """ - return self.enable_otel or self.enable_sensitive_data + return self.enable_instrumentation @property def SENSITIVE_DATA_ENABLED(self) -> bool: @@ -409,27 +623,18 @@ def SENSITIVE_DATA_ENABLED(self) -> bool: Sensitive events are enabled if the diagnostic with sensitive events is enabled. """ - return self.enable_sensitive_data + return self.enable_instrumentation and self.enable_sensitive_data @property def is_setup(self) -> bool: """Check if the setup has been executed.""" return self._executed_setup - @property - def resource(self) -> "Resource": - """Get the resource.""" - return self._resource - - @resource.setter - def resource(self, value: "Resource") -> None: - """Set the resource.""" - self._resource = value - def _configure( self, - credential: "TokenCredential | None" = None, + *, additional_exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] | None = None, + views: list["View"] | None = None, ) -> None: """Configure application-wide observability based on the settings. @@ -438,120 +643,102 @@ def _configure( will have no effect. Args: - credential: The credential to use for Azure Monitor Entra ID authentication. Default is None. additional_exporters: A list of additional exporters to add to the configuration. Default is None. + views: Optional list of OpenTelemetry views for metrics. Default is None. """ if not self.ENABLED or self._executed_setup: return - exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = additional_exporters or [] - if self.otlp_endpoint: - exporters.extend( - _get_otlp_exporters( - self.otlp_endpoint if isinstance(self.otlp_endpoint, list) else [self.otlp_endpoint] - ) - ) - if self.applicationinsights_connection_string: - exporters.extend( - _get_azure_monitor_exporters( - connection_strings=( - self.applicationinsights_connection_string - if isinstance(self.applicationinsights_connection_string, list) - else [self.applicationinsights_connection_string] - ), - credential=credential, - ) + exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = [] + + # 1. Add exporters from standard OTEL environment variables + exporters.extend( + _get_exporters_from_env( + env_file_path=self.env_file_path, + env_file_encoding=self.env_file_encoding, ) - self._configure_providers(exporters) - self._executed_setup = True + ) - def check_endpoint_already_configured(self, otlp_endpoint: str) -> bool: - """Check if the endpoint is already configured. + # 2. Add passed-in exporters + if additional_exporters: + exporters.extend(additional_exporters) - Returns: - True if the endpoint is already configured, False otherwise. - """ - if not self.otlp_endpoint: - return False - return otlp_endpoint in (self.otlp_endpoint if isinstance(self.otlp_endpoint, list) else [self.otlp_endpoint]) + # 3. Add console exporters if explicitly enabled + if self.enable_console_exporters: + from opentelemetry.sdk._logs.export import ConsoleLogRecordExporter + from opentelemetry.sdk.metrics.export import ConsoleMetricExporter + from opentelemetry.sdk.trace.export import ConsoleSpanExporter - def check_connection_string_already_configured(self, connection_string: str) -> bool: - """Check if the connection string is already configured. + exporters.extend([ConsoleSpanExporter(), ConsoleLogRecordExporter(), ConsoleMetricExporter()]) - Returns: - True if the connection string is already configured, False otherwise. - """ - if not self.applicationinsights_connection_string: - return False - return connection_string in ( - self.applicationinsights_connection_string - if isinstance(self.applicationinsights_connection_string, list) - else [self.applicationinsights_connection_string] - ) + # 4. Add VS Code extension exporters if port is specified + if self.vs_code_extension_port: + endpoint = f"http://localhost:{self.vs_code_extension_port}" + exporters.extend(_create_otlp_exporters(endpoint=endpoint, protocol="grpc")) + + # 5. Configure providers + self._configure_providers(exporters, views=views) + self._executed_setup = True + + def _configure_providers( + self, + exporters: list["LogRecordExporter | MetricExporter | SpanExporter"], + views: list["View"] | None = None, + ) -> None: + """Configure tracing, logging, events and metrics with the provided exporters. - def _configure_providers(self, exporters: list["LogRecordExporter | MetricExporter | SpanExporter"]) -> None: - """Configure tracing, logging, events and metrics with the provided exporters.""" + Args: + exporters: A list of exporters for logs, metrics and/or spans. + views: Optional list of OpenTelemetry views for metrics. Default is empty list. + """ from opentelemetry._logs import set_logger_provider from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogRecordExporter from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import MetricExporter, PeriodicExportingMetricReader - from opentelemetry.sdk.metrics.view import DropAggregation, View from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter + span_exporters: list[SpanExporter] = [] + log_exporters: list[LogRecordExporter] = [] + metric_exporters: list[MetricExporter] = [] + for exp in exporters: + if isinstance(exp, SpanExporter): + span_exporters.append(exp) + if isinstance(exp, LogRecordExporter): + log_exporters.append(exp) + if isinstance(exp, MetricExporter): + metric_exporters.append(exp) + # Tracing - tracer_provider = TracerProvider(resource=self.resource) - trace.set_tracer_provider(tracer_provider) - should_add_console_exporter = True - for exporter in exporters: - if isinstance(exporter, SpanExporter): + if span_exporters: + tracer_provider = TracerProvider(resource=self._resource) + trace.set_tracer_provider(tracer_provider) + for exporter in span_exporters: tracer_provider.add_span_processor(BatchSpanProcessor(exporter)) - should_add_console_exporter = False - if should_add_console_exporter: - from opentelemetry.sdk.trace.export import ConsoleSpanExporter - - tracer_provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter())) # Logging - logger_provider = LoggerProvider(resource=self.resource) - should_add_console_exporter = True - for exporter in exporters: - if isinstance(exporter, LogRecordExporter): - logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter)) - should_add_console_exporter = False - if should_add_console_exporter: - from opentelemetry.sdk._logs.export import ConsoleLogRecordExporter - - logger_provider.add_log_record_processor(BatchLogRecordProcessor(ConsoleLogRecordExporter())) - - # Attach a handler with the provider to the root logger - logger = logging.getLogger() - handler = LoggingHandler(logger_provider=logger_provider) - logger.addHandler(handler) - set_logger_provider(logger_provider) + if log_exporters: + logger_provider = LoggerProvider(resource=self._resource) + for log_exporter in log_exporters: + logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) + # Attach a handler with the provider to the root logger + logger = logging.getLogger() + handler = LoggingHandler(logger_provider=logger_provider) + logger.addHandler(handler) + set_logger_provider(logger_provider) # metrics - metric_readers = [ - PeriodicExportingMetricReader(exporter, export_interval_millis=5000) - for exporter in exporters - if isinstance(exporter, MetricExporter) - ] - if not metric_readers: - from opentelemetry.sdk.metrics.export import ConsoleMetricExporter - - metric_readers = [PeriodicExportingMetricReader(ConsoleMetricExporter(), export_interval_millis=5000)] - meter_provider = MeterProvider( - metric_readers=metric_readers, - resource=self.resource, - views=[ - # Dropping all instrument names except for those starting with "agent_framework" - View(instrument_name="*", aggregation=DropAggregation()), - View(instrument_name="agent_framework*"), - View(instrument_name="gen_ai*"), - ], - ) - metrics.set_meter_provider(meter_provider) + if metric_exporters: + meter_provider = MeterProvider( + metric_readers=[ + PeriodicExportingMetricReader(exporter, export_interval_millis=5000) + for exporter in metric_exporters + ], + resource=self._resource, + views=views or [], + ) + metrics.set_meter_provider(meter_provider) def get_tracer( @@ -661,125 +848,174 @@ def get_meter( OBSERVABILITY_SETTINGS: ObservabilitySettings = ObservabilitySettings() -def setup_observability( +def enable_instrumentation( + *, + enable_sensitive_data: bool | None = None, +) -> None: + """Enable instrumentation for your application. + + Calling this method implies you want to enable observability in your application. + + This method does not configure exporters or providers. + It only updates the global variables that trigger the instrumentation code. + If you have already set the environment variable ENABLE_INSTRUMENTATION=true, + calling this method has no effect, unless you want to enable or disable sensitive data events. + + Keyword Args: + enable_sensitive_data: Enable OpenTelemetry sensitive events. Overrides + the environment variable ENABLE_SENSITIVE_DATA if set. Default is None. + """ + global OBSERVABILITY_SETTINGS + OBSERVABILITY_SETTINGS.enable_instrumentation = True + if enable_sensitive_data is not None: + OBSERVABILITY_SETTINGS.enable_sensitive_data = enable_sensitive_data + + +def configure_otel_providers( + *, enable_sensitive_data: bool | None = None, - otlp_endpoint: str | list[str] | None = None, - applicationinsights_connection_string: str | list[str] | None = None, - credential: "TokenCredential | None" = None, exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] | None = None, + views: list["View"] | None = None, vs_code_extension_port: int | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, ) -> None: - """Setup observability for the application with OpenTelemetry. + """Configure otel providers and enable instrumentation for the application with OpenTelemetry. This method creates the exporters and providers for the application based on - the provided values and environment variables. + the provided values and environment variables and enables instrumentation. Call this method once during application startup, before any telemetry is captured. DO NOT call this method multiple times, as it may lead to unexpected behavior. + The function automatically reads standard OpenTelemetry environment variables: + - OTEL_EXPORTER_OTLP_ENDPOINT: Base OTLP endpoint for all signals + - OTEL_EXPORTER_OTLP_TRACES_ENDPOINT: OTLP endpoint for traces + - OTEL_EXPORTER_OTLP_METRICS_ENDPOINT: OTLP endpoint for metrics + - OTEL_EXPORTER_OTLP_LOGS_ENDPOINT: OTLP endpoint for logs + - OTEL_EXPORTER_OTLP_PROTOCOL: Protocol (grpc/http) + - OTEL_EXPORTER_OTLP_HEADERS: Headers for all signals + - ENABLE_CONSOLE_EXPORTERS: Enable console output for telemetry + Note: - If you have configured the providers manually, calling this method will not - have any effect. The reverse is also true - if you call this method first, - subsequent provider configurations will not take effect. + Since you can only setup one provider per signal type (logs, traces, metrics), + you can choose to use this method and take the exporter and provider that we created. + Alternatively, you can setup the providers yourself, or through another library + (e.g., Azure Monitor) and just call `enable_instrumentation()` to enable instrumentation. - Args: + Note: + By default, the Agent Framework emits metrics with the prefixes `agent_framework` + and `gen_ai` (OpenTelemetry GenAI semantic conventions). You can use the `views` + parameter to filter which metrics are collected and exported. You can also use + the `create_metric_views()` helper function to get default views. + + Keyword Args: enable_sensitive_data: Enable OpenTelemetry sensitive events. Overrides - the environment variable if set. Default is None. - otlp_endpoint: The OpenTelemetry Protocol (OTLP) endpoint. Will be used - to create OTLPLogExporter, OTLPMetricExporter and OTLPSpanExporter. - Default is None. - applicationinsights_connection_string: The Azure Monitor connection string. - Will be used to create AzureMonitorExporters. Default is None. - credential: The credential to use for Azure Monitor Entra ID authentication. + the environment variable ENABLE_SENSITIVE_DATA if set. Default is None. + exporters: A list of custom exporters for logs, metrics or spans, or any combination. + These will be added in addition to exporters configured via environment variables. Default is None. - exporters: A list of exporters for logs, metrics or spans, or any combination. - These will be added directly, allowing complete customization. Default is None. - vs_code_extension_port: The port the AI Toolkit or AzureAI Foundry VS Code + views: Optional list of OpenTelemetry views for metrics configuration. + Views allow filtering and customizing which metrics are collected. + Default is None (empty list). + vs_code_extension_port: The port the AI Toolkit or Azure AI Foundry VS Code extensions are listening on. When set, additional OTEL exporters will be - created with endpoint `http://localhost:{vs_code_extension_port}` unless - already configured. Overrides the environment variable if set. Default is None. + created with endpoint `http://localhost:{vs_code_extension_port}`. + Overrides the environment variable VS_CODE_EXTENSION_PORT if set. Default is None. + env_file_path: An optional path to a .env file to load environment variables from. + Default is None. + env_file_encoding: The encoding to use when loading the .env file. Default is None + which uses the system default encoding. Examples: .. code-block:: python - from agent_framework import setup_observability - - # With environment variables - # Set ENABLE_OTEL=true, OTLP_ENDPOINT=http://localhost:4317 - setup_observability() + from agent_framework.observability import configure_otel_providers - # With parameters (no environment variables) - setup_observability( - enable_sensitive_data=True, - otlp_endpoint="http://localhost:4317", - ) + # Using environment variables (recommended) + # Set ENABLE_INSTRUMENTATION=true + # Set OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317 + configure_otel_providers() - # With Azure Monitor - setup_observability( - applicationinsights_connection_string="InstrumentationKey=...", - ) + # Enable console output for debugging + # Set ENABLE_CONSOLE_EXPORTERS=true + configure_otel_providers() # With custom exporters - from opentelemetry.sdk.trace.export import ConsoleSpanExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter - setup_observability( - exporters=[ConsoleSpanExporter()], - ) - - # Mixed: combine environment variables and parameters - # Environment: OTLP_ENDPOINT=http://localhost:7431 - # Code adds additional endpoint - setup_observability( - enable_sensitive_data=True, - otlp_endpoint="http://localhost:4317", # Both endpoints will be used + configure_otel_providers( + exporters=[ + OTLPSpanExporter(endpoint="http://custom:4317"), + OTLPLogExporter(endpoint="http://custom:4317"), + ], ) # VS Code extension integration - setup_observability( + configure_otel_providers( vs_code_extension_port=4317, # Connects to AI Toolkit ) - """ - global OBSERVABILITY_SETTINGS - # Update the observability settings with the provided values - OBSERVABILITY_SETTINGS.enable_otel = True - if enable_sensitive_data is not None: - OBSERVABILITY_SETTINGS.enable_sensitive_data = enable_sensitive_data - if vs_code_extension_port is not None: - OBSERVABILITY_SETTINGS.vs_code_extension_port = vs_code_extension_port - - # Create exporters, after checking if they are already configured through the env. - new_exporters: list["LogRecordExporter | SpanExporter | MetricExporter"] = exporters or [] - if otlp_endpoint: - if isinstance(otlp_endpoint, str): - otlp_endpoint = [otlp_endpoint] - new_exporters.extend( - _get_otlp_exporters( - endpoints=[ - endpoint - for endpoint in otlp_endpoint - if not OBSERVABILITY_SETTINGS.check_endpoint_already_configured(endpoint) - ] + + # Enable sensitive data logging (development only) + configure_otel_providers( + enable_sensitive_data=True, ) - ) - if applicationinsights_connection_string: - if isinstance(applicationinsights_connection_string, str): - applicationinsights_connection_string = [applicationinsights_connection_string] - new_exporters.extend( - _get_azure_monitor_exporters( - connection_strings=[ - conn_str - for conn_str in applicationinsights_connection_string - if not OBSERVABILITY_SETTINGS.check_connection_string_already_configured(conn_str) + + # With custom metrics views + from opentelemetry.sdk.metrics.view import View + + configure_otel_providers( + views=[ + View(instrument_name="agent_framework*"), + View(instrument_name="gen_ai*"), ], - credential=credential, ) - ) - if OBSERVABILITY_SETTINGS.vs_code_extension_port: - endpoint = f"http://localhost:{OBSERVABILITY_SETTINGS.vs_code_extension_port}" - if not OBSERVABILITY_SETTINGS.check_endpoint_already_configured(endpoint): - new_exporters.extend(_get_otlp_exporters(endpoints=[endpoint])) - OBSERVABILITY_SETTINGS._configure(credential=credential, additional_exporters=new_exporters) # pyright: ignore[reportPrivateUsage] + This example shows how to first setup your providers, + and then ensure Agent Framework emits traces, logs and metrics + + .. code-block:: python + + # when azure monitor is installed + from agent_framework.observability import enable_instrumentation + from azure.monitor.opentelemetry import configure_azure_monitor + + connection_string = "InstrumentationKey=your_instrumentation_key_here;..." + configure_azure_monitor(connection_string=connection_string) + enable_instrumentation() + + References: + - https://opentelemetry.io/docs/languages/sdk-configuration/general/ + - https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter/ + """ + global OBSERVABILITY_SETTINGS + if env_file_path: + # Build kwargs, excluding None values + settings_kwargs: dict[str, Any] = { + "enable_instrumentation": True, + "env_file_path": env_file_path, + } + if env_file_encoding is not None: + settings_kwargs["env_file_encoding"] = env_file_encoding + if enable_sensitive_data is not None: + settings_kwargs["enable_sensitive_data"] = enable_sensitive_data + if vs_code_extension_port is not None: + settings_kwargs["vs_code_extension_port"] = vs_code_extension_port + + OBSERVABILITY_SETTINGS = ObservabilitySettings(**settings_kwargs) + else: + # Update the observability settings with the provided values + OBSERVABILITY_SETTINGS.enable_instrumentation = True + if enable_sensitive_data is not None: + OBSERVABILITY_SETTINGS.enable_sensitive_data = enable_sensitive_data + if vs_code_extension_port is not None: + OBSERVABILITY_SETTINGS.vs_code_extension_port = vs_code_extension_port + + OBSERVABILITY_SETTINGS._configure( # type: ignore[reportPrivateUsage] + additional_exporters=exporters, + views=views, + ) # region Chat Client Telemetry @@ -993,7 +1229,7 @@ async def trace_get_streaming_response( return decorator(func) -def use_observability( +def use_instrumentation( chat_client: type[TChatClient], ) -> type[TChatClient]: """Class decorator that enables OpenTelemetry observability for a chat client. @@ -1019,12 +1255,12 @@ def use_observability( Examples: .. code-block:: python - from agent_framework import use_observability, setup_observability + from agent_framework import use_instrumentation, configure_otel_providers from agent_framework import ChatClientProtocol # Decorate a custom chat client class - @use_observability + @use_instrumentation class MyCustomChatClient: OTEL_PROVIDER_NAME = "my_provider" @@ -1038,7 +1274,7 @@ async def get_streaming_response(self, messages, **kwargs): # Setup observability - setup_observability(otlp_endpoint="http://localhost:4317") + configure_otel_providers(otlp_endpoint="http://localhost:4317") # Now all calls will be traced client = MyCustomChatClient() @@ -1082,12 +1318,14 @@ async def get_streaming_response(self, messages, **kwargs): def _trace_agent_run( run_func: Callable[..., Awaitable["AgentRunResponse"]], provider_name: str, + capture_usage: bool = True, ) -> Callable[..., Awaitable["AgentRunResponse"]]: """Decorator to trace chat completion activities. Args: run_func: The function to trace. provider_name: The system name used for Open Telemetry. + capture_usage: Whether to capture token usage as a span attribute. """ @wraps(run_func) @@ -1128,7 +1366,7 @@ async def trace_run( capture_exception(span=span, exception=exception, timestamp=time_ns()) raise else: - attributes = _get_response_attributes(attributes, response) + attributes = _get_response_attributes(attributes, response, capture_usage=capture_usage) _capture_response(span=span, attributes=attributes) if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: _capture_messages( @@ -1145,12 +1383,14 @@ async def trace_run( def _trace_agent_run_stream( run_streaming_func: Callable[..., AsyncIterable["AgentRunResponseUpdate"]], provider_name: str, + capture_usage: bool, ) -> Callable[..., AsyncIterable["AgentRunResponseUpdate"]]: """Decorator to trace streaming agent run activities. Args: run_streaming_func: The function to trace. provider_name: The system name used for Open Telemetry. + capture_usage: Whether to capture token usage as a span attribute. """ @wraps(run_streaming_func) @@ -1201,7 +1441,7 @@ async def trace_run_streaming( raise else: response = AgentRunResponse.from_agent_run_response_updates(all_updates) - attributes = _get_response_attributes(attributes, response) + attributes = _get_response_attributes(attributes, response, capture_usage=capture_usage) _capture_response(span=span, attributes=attributes) if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: _capture_messages( @@ -1214,9 +1454,11 @@ async def trace_run_streaming( return trace_run_streaming -def use_agent_observability( - agent: type[TAgent], -) -> type[TAgent]: +def use_agent_instrumentation( + agent: type[TAgent] | None = None, + *, + capture_usage: bool = True, +) -> type[TAgent] | Callable[[type[TAgent]], type[TAgent]]: """Class decorator that enables OpenTelemetry observability for an agent. This decorator automatically traces agent run requests, captures events, @@ -1224,12 +1466,17 @@ def use_agent_observability( Note: This decorator must be applied to the agent class itself, not an instance. - The agent class should have a class variable AGENT_SYSTEM_NAME to set the + The agent class should have a class variable AGENT_PROVIDER_NAME to set the proper system name for telemetry. Args: agent: The agent class to enable observability for. + Keyword Args: + capture_usage: Whether to capture token usage as a span attribute. + Defaults to True, set to False when the agent has underlying traces + that already capture token usage to avoid double counting. + Returns: The decorated agent class with observability enabled. @@ -1240,14 +1487,14 @@ def use_agent_observability( Examples: .. code-block:: python - from agent_framework import use_agent_observability, setup_observability + from agent_framework import use_agent_instrumentation, configure_otel_providers from agent_framework._agents import AgentProtocol # Decorate a custom agent class - @use_agent_observability + @use_agent_instrumentation class MyCustomAgent: - AGENT_SYSTEM_NAME = "my_agent_system" + AGENT_PROVIDER_NAME = "my_agent_system" async def run(self, messages=None, *, thread=None, **kwargs): # Your implementation @@ -1259,23 +1506,31 @@ async def run_stream(self, messages=None, *, thread=None, **kwargs): # Setup observability - setup_observability(otlp_endpoint="http://localhost:4317") + configure_otel_providers(otlp_endpoint="http://localhost:4317") # Now all agent runs will be traced agent = MyCustomAgent() response = await agent.run("Perform a task") """ - provider_name = str(getattr(agent, "AGENT_SYSTEM_NAME", "Unknown")) - try: - agent.run = _trace_agent_run(agent.run, provider_name) # type: ignore - except AttributeError as exc: - raise AgentInitializationError(f"The agent {agent.__name__} does not have a run method.", exc) from exc - try: - agent.run_stream = _trace_agent_run_stream(agent.run_stream, provider_name) # type: ignore - except AttributeError as exc: - raise AgentInitializationError(f"The agent {agent.__name__} does not have a run_stream method.", exc) from exc - setattr(agent, OPEN_TELEMETRY_AGENT_MARKER, True) - return agent + + def decorator(agent: type[TAgent]) -> type[TAgent]: + provider_name = str(getattr(agent, "AGENT_PROVIDER_NAME", "Unknown")) + try: + agent.run = _trace_agent_run(agent.run, provider_name, capture_usage=capture_usage) # type: ignore + except AttributeError as exc: + raise AgentInitializationError(f"The agent {agent.__name__} does not have a run method.", exc) from exc + try: + agent.run_stream = _trace_agent_run_stream(agent.run_stream, provider_name, capture_usage=capture_usage) # type: ignore + except AttributeError as exc: + raise AgentInitializationError( + f"The agent {agent.__name__} does not have a run_stream method.", exc + ) from exc + setattr(agent, OPEN_TELEMETRY_AGENT_MARKER, True) + return agent + + if agent is None: + return decorator + return decorator(agent) # region Otel Helpers @@ -1458,26 +1713,32 @@ def _to_otel_part(content: "Contents") -> dict[str, Any] | None: match content.type: case "text": return {"type": "text", "content": content.text} + case "text_reasoning": + return {"type": "reasoning", "content": content.text} + case "uri": + return { + "type": "uri", + "uri": content.uri, + "mime_type": content.media_type, + "modality": content.media_type.split("/")[0] if content.media_type else None, + } + case "data": + return { + "type": "blob", + "content": content.get_data_bytes_as_str(), + "mime_type": content.media_type, + "modality": content.media_type.split("/")[0] if content.media_type else None, + } case "function_call": return {"type": "tool_call", "id": content.call_id, "name": content.name, "arguments": content.arguments} case "function_result": - response: Any | None = None - if content.result: - if isinstance(content.result, list): - res: list[Any] = [] - for item in content.result: # type: ignore - from ._types import BaseContent - - if isinstance(item, BaseContent): - res.append(_to_otel_part(item)) # type: ignore - elif isinstance(item, BaseModel): - res.append(item.model_dump(exclude_none=True)) - else: - res.append(json.dumps(item, default=str)) - response = json.dumps(res, default=str) - else: - response = json.dumps(content.result, default=str) - return {"type": "tool_call_response", "id": content.call_id, "response": response} + from ._types import prepare_function_call_results + + return { + "type": "tool_call_response", + "id": content.call_id, + "response": prepare_function_call_results(content), + } case _: # GenericPart in otel output messages json spec. # just required type, and arbitrary other fields. @@ -1489,6 +1750,8 @@ def _get_response_attributes( attributes: dict[str, Any], response: "ChatResponse | AgentRunResponse", duration: float | None = None, + *, + capture_usage: bool = True, ) -> dict[str, Any]: """Get the response attributes from a response.""" if response.response_id: @@ -1502,7 +1765,7 @@ def _get_response_attributes( attributes[OtelAttr.FINISH_REASONS] = json.dumps([finish_reason.value]) if model_id := getattr(response, "model_id", None): attributes[SpanAttributes.LLM_RESPONSE_MODEL] = model_id - if usage := response.usage_details: + if capture_usage and (usage := response.usage_details): if usage.input_token_count: attributes[OtelAttr.INPUT_TOKENS] = usage.input_token_count if usage.output_token_count: diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index 0f3bb3de63..319ad95231 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -40,7 +40,7 @@ prepare_function_call_results, ) from ..exceptions import ServiceInitializationError -from ..observability import use_observability +from ..observability import use_instrumentation from ._shared import OpenAIConfigMixin, OpenAISettings if sys.version_info >= (3, 11): @@ -53,7 +53,7 @@ @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient): """OpenAI Assistants client.""" diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 73605fadef..7f0feb0fc7 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -44,7 +44,7 @@ ServiceInvalidRequestError, ServiceResponseException, ) -from ..observability import use_observability +from ..observability import use_instrumentation from ._exceptions import OpenAIContentFilterException from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings @@ -162,6 +162,7 @@ def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: exclude={ "type", "instructions", # included as system message + "allow_multiple_tool_calls", # handled separately } ) @@ -174,6 +175,8 @@ def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: if web_search_options: options_dict["web_search_options"] = web_search_options options_dict["tools"] = self._chat_to_tool_spec(chat_options.tools) + if chat_options.allow_multiple_tool_calls is not None: + options_dict["parallel_tool_calls"] = chat_options.allow_multiple_tool_calls if not options_dict.get("tools", None): options_dict.pop("tools", None) options_dict.pop("parallel_tool_calls", None) @@ -467,7 +470,7 @@ def service_url(self) -> str: @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient): """OpenAI Chat completion class.""" diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index d1857fb4fe..ecdd7be660 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -3,7 +3,7 @@ from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence from datetime import datetime, timezone from itertools import chain -from typing import Any, TypeVar +from typing import Any, TypeVar, cast from openai import AsyncOpenAI, BadRequestError from openai.types.responses.file_search_tool_param import FileSearchToolParam @@ -64,7 +64,7 @@ ServiceInvalidRequestError, ServiceResponseException, ) -from ..observability import use_observability +from ..observability import use_instrumentation from ._exceptions import OpenAIContentFilterException from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings @@ -199,7 +199,7 @@ def _prepare_text_config( return response_format, prepared_text if isinstance(response_format, Mapping): - format_config = self._convert_response_format(response_format) + format_config = self._convert_response_format(cast("Mapping[str, Any]", response_format)) if prepared_text is None: prepared_text = {} elif "format" in prepared_text and prepared_text["format"] != format_config: @@ -212,20 +212,21 @@ def _prepare_text_config( def _convert_response_format(self, response_format: Mapping[str, Any]) -> dict[str, Any]: """Convert Chat style response_format into Responses text format config.""" if "format" in response_format and isinstance(response_format["format"], Mapping): - return dict(response_format["format"]) + return dict(cast("Mapping[str, Any]", response_format["format"])) format_type = response_format.get("type") if format_type == "json_schema": schema_section = response_format.get("json_schema", response_format) if not isinstance(schema_section, Mapping): raise ServiceInvalidRequestError("json_schema response_format must be a mapping.") - schema = schema_section.get("schema") + schema_section_typed = cast("Mapping[str, Any]", schema_section) + schema: Any = schema_section_typed.get("schema") if schema is None: raise ServiceInvalidRequestError("json_schema response_format requires a schema.") - name = ( - schema_section.get("name") - or schema_section.get("title") - or (schema.get("title") if isinstance(schema, Mapping) else None) + name: str = str( + schema_section_typed.get("name") + or schema_section_typed.get("title") + or (cast("Mapping[str, Any]", schema).get("title") if isinstance(schema, Mapping) else None) or "response" ) format_config: dict[str, Any] = { @@ -532,12 +533,13 @@ def _openai_content_parser( "text": content.text, }, } - if content.additional_properties is not None: - if status := content.additional_properties.get("status"): + props: dict[str, Any] | None = getattr(content, "additional_properties", None) + if props: + if status := props.get("status"): ret["status"] = status - if reasoning_text := content.additional_properties.get("reasoning_text"): + if reasoning_text := props.get("reasoning_text"): ret["content"] = {"type": "reasoning_text", "text": reasoning_text} - if encrypted_content := content.additional_properties.get("encrypted_content"): + if encrypted_content := props.get("encrypted_content"): ret["encrypted_content"] = encrypted_content return ret case DataContent() | UriContent(): @@ -824,7 +826,7 @@ def _create_response_content( "raw_representation": response, } - conversation_id = self.get_conversation_id(response, chat_options.store) + conversation_id = self.get_conversation_id(response, chat_options.store) # type: ignore[reportArgumentType] if conversation_id: args["conversation_id"] = conversation_id @@ -911,6 +913,8 @@ def _create_streaming_response_content( metadata.update(self._get_metadata_from_response(event_part)) case "refusal": contents.append(TextContent(text=event_part.refusal, raw_representation=event)) + case _: + pass case "response.output_text.delta": contents.append(TextContent(text=event.delta, raw_representation=event)) metadata.update(self._get_metadata_from_response(event)) @@ -1032,6 +1036,60 @@ def _create_streaming_response_content( raw_representation=event, ) ) + case "response.output_text.annotation.added": + # Handle streaming text annotations (file citations, file paths, etc.) + annotation: Any = event.annotation + + def _get_ann_value(key: str) -> Any: + """Extract value from annotation (dict or object).""" + if isinstance(annotation, dict): + return cast("dict[str, Any]", annotation).get(key) + return getattr(annotation, key, None) + + ann_type = _get_ann_value("type") + ann_file_id = _get_ann_value("file_id") + if ann_type == "file_path": + if ann_file_id: + contents.append( + HostedFileContent( + file_id=str(ann_file_id), + additional_properties={ + "annotation_index": event.annotation_index, + "index": _get_ann_value("index"), + }, + raw_representation=event, + ) + ) + elif ann_type == "file_citation": + if ann_file_id: + contents.append( + HostedFileContent( + file_id=str(ann_file_id), + additional_properties={ + "annotation_index": event.annotation_index, + "filename": _get_ann_value("filename"), + "index": _get_ann_value("index"), + }, + raw_representation=event, + ) + ) + elif ann_type == "container_file_citation": + if ann_file_id: + contents.append( + HostedFileContent( + file_id=str(ann_file_id), + additional_properties={ + "annotation_index": event.annotation_index, + "container_id": _get_ann_value("container_id"), + "filename": _get_ann_value("filename"), + "start_index": _get_ann_value("start_index"), + "end_index": _get_ann_value("end_index"), + }, + raw_representation=event, + ) + ) + else: + logger.debug("Unparsed annotation type in streaming: %s", ann_type) case _: logger.debug("Unparsed event of type: %s: %s", event.type, event) @@ -1069,7 +1127,7 @@ def _get_metadata_from_response(self, output: Any) -> dict[str, Any]: @use_function_invocation -@use_observability +@use_instrumentation @use_chat_middleware class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient): """OpenAI Responses client class.""" diff --git a/python/packages/core/agent_framework/openai/_shared.py b/python/packages/core/agent_framework/openai/_shared.py index 511c1f3379..e0df8844e4 100644 --- a/python/packages/core/agent_framework/openai/_shared.py +++ b/python/packages/core/agent_framework/openai/_shared.py @@ -63,7 +63,7 @@ def _check_openai_version_for_callable_api_key() -> None: raise ServiceInitializationError( f"Callable API keys require OpenAI SDK >= 1.106.0, but you have {openai.__version__}. " f"Please upgrade with 'pip install openai>=1.106.0' or provide a string API key instead. " - f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=0.1.118 " + f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=1.0.0 " f"to allow newer OpenAI versions." ) except ServiceInitializationError: diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index aa8e2611f5..8eec13e8e6 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -4,7 +4,7 @@ description = "Microsoft Agent Framework for building AI Agents with Python. Thi authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" @@ -30,12 +30,11 @@ dependencies = [ # telemetry "opentelemetry-api>=1.39.0", "opentelemetry-sdk>=1.39.0", - "opentelemetry-exporter-otlp-proto-grpc>=1.39.0", "opentelemetry-semantic-conventions-ai>=0.4.13", # connectors and functions "openai>=1.99.0", "azure-identity>=1,<2", - "mcp[ws]>=1.13", + "mcp[ws]>=1.23", "packaging>=24.1", ] diff --git a/python/packages/core/tests/conftest.py b/python/packages/core/tests/conftest.py index d356e300bb..fd8b93ebc2 100644 --- a/python/packages/core/tests/conftest.py +++ b/python/packages/core/tests/conftest.py @@ -10,7 +10,7 @@ @fixture -def enable_otel(request: Any) -> bool: +def enable_instrumentation(request: Any) -> bool: """Fixture that returns a boolean indicating if Otel is enabled.""" return request.param if hasattr(request, "param") else True @@ -22,20 +22,31 @@ def enable_sensitive_data(request: Any) -> bool: @fixture -def span_exporter(monkeypatch, enable_otel: bool, enable_sensitive_data: bool) -> Generator[SpanExporter]: +def span_exporter(monkeypatch, enable_instrumentation: bool, enable_sensitive_data: bool) -> Generator[SpanExporter]: """Fixture to remove environment variables for ObservabilitySettings.""" env_vars = [ - "ENABLE_OTEL", + "ENABLE_INSTRUMENTATION", "ENABLE_SENSITIVE_DATA", - "OTLP_ENDPOINT", - "APPLICATIONINSIGHTS_CONNECTION_STRING", + "ENABLE_CONSOLE_EXPORTERS", + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + "OTEL_EXPORTER_OTLP_PROTOCOL", + "OTEL_EXPORTER_OTLP_HEADERS", + "OTEL_EXPORTER_OTLP_TRACES_HEADERS", + "OTEL_EXPORTER_OTLP_METRICS_HEADERS", + "OTEL_EXPORTER_OTLP_LOGS_HEADERS", + "OTEL_SERVICE_NAME", + "OTEL_SERVICE_VERSION", + "OTEL_RESOURCE_ATTRIBUTES", ] for key in env_vars: monkeypatch.delenv(key, raising=False) # type: ignore - monkeypatch.setenv("ENABLE_OTEL", str(enable_otel)) # type: ignore - if not enable_otel: + monkeypatch.setenv("ENABLE_INSTRUMENTATION", str(enable_instrumentation)) # type: ignore + if not enable_instrumentation: # we overwrite sensitive data for tests enable_sensitive_data = False monkeypatch.setenv("ENABLE_SENSITIVE_DATA", str(enable_sensitive_data)) # type: ignore @@ -51,15 +62,22 @@ def span_exporter(monkeypatch, enable_otel: bool, enable_sensitive_data: bool) - # recreate observability settings with values from above and no file. observability_settings = observability.ObservabilitySettings(env_file_path="test.env") - observability_settings._configure() # pyright: ignore[reportPrivateUsage] + + # Configure providers manually without calling _configure() to avoid OTLP imports + if enable_instrumentation or enable_sensitive_data: + from opentelemetry.sdk.trace import TracerProvider + + tracer_provider = TracerProvider(resource=observability_settings._resource) + trace.set_tracer_provider(tracer_provider) + monkeypatch.setattr(observability, "OBSERVABILITY_SETTINGS", observability_settings, raising=False) # type: ignore with ( patch("agent_framework.observability.OBSERVABILITY_SETTINGS", observability_settings), - patch("agent_framework.observability.setup_observability"), + patch("agent_framework.observability.configure_otel_providers"), ): exporter = InMemorySpanExporter() - if enable_otel or enable_sensitive_data: + if enable_instrumentation or enable_sensitive_data: tracer_provider = trace.get_tracer_provider() if not hasattr(tracer_provider, "add_span_processor"): raise RuntimeError("Tracer provider does not support adding span processors.") diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 77d5911865..a6df07cbbe 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -21,9 +21,11 @@ ChatResponse, Context, ContextProvider, + FunctionCallContent, HostedCodeInterpreterTool, Role, TextContent, + ai_function, ) from agent_framework._mcp import MCPTool from agent_framework.exceptions import AgentExecutionException @@ -595,3 +597,38 @@ async def test_chat_agent_with_local_mcp_tools(chat_client: ChatClientProtocol) # Test async context manager with MCP tools async with agent: pass + + +async def test_agent_tool_receives_thread_in_kwargs(chat_client_base: Any) -> None: + """Verify tool execution receives 'thread' inside **kwargs when function is called by client.""" + + captured: dict[str, Any] = {} + + @ai_function(name="echo_thread_info") + def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnknownParameterType] + thread = kwargs.get("thread") + captured["has_thread"] = thread is not None + captured["has_message_store"] = thread.message_store is not None if isinstance(thread, AgentThread) else False + return f"echo: {text}" + + # Make the base client emit a function call for our tool + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[FunctionCallContent(call_id="1", name="echo_thread_info", arguments='{"text": "hello"}')], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + agent = ChatAgent( + chat_client=chat_client_base, tools=[echo_thread_info], chat_message_store_factory=ChatMessageStore + ) + thread = agent.get_new_thread() + + result = await agent.run("hello", thread=thread) + + assert result.text == "done" + assert captured.get("has_thread") is True + assert captured.get("has_message_store") is True diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 5a0ec5a773..bc96ddcc35 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from collections.abc import Awaitable, Callable + import pytest from agent_framework import ( @@ -16,6 +18,7 @@ TextContent, ai_function, ) +from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware async def test_base_client_with_function_calling(chat_client_base: ChatClientProtocol): @@ -2206,3 +2209,175 @@ def sometimes_fails(arg1: str) -> str: assert len(error_results) >= 1 assert len(success_results) >= 1 assert call_count == 2 # Both calls executed + + +class TerminateLoopMiddleware(FunctionMiddleware): + """Middleware that sets terminate=True to exit the function calling loop.""" + + async def process( + self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + # Set result to a simple value - the framework will wrap it in FunctionResultContent + context.result = "terminated by middleware" + context.terminate = True + + +async def test_terminate_loop_single_function_call(chat_client_base: ChatClientProtocol): + """Test that terminate_loop=True exits the function calling loop after single function call.""" + exec_counter = 0 + + @ai_function(name="test_function") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}" + + # Queue up two responses: function call, then final text + # If terminate_loop works, only the first response should be consumed + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + response = await chat_client_base.get_response( + "hello", + tool_choice="auto", + tools=[ai_func], + middleware=[TerminateLoopMiddleware()], + ) + + # Function should NOT have been executed - middleware intercepted it + assert exec_counter == 0 + + # There should be 2 messages: assistant with function call, tool result from middleware + # The loop should NOT have continued to call the LLM again + assert len(response.messages) == 2 + assert response.messages[0].role == Role.ASSISTANT + assert isinstance(response.messages[0].contents[0], FunctionCallContent) + assert response.messages[1].role == Role.TOOL + assert isinstance(response.messages[1].contents[0], FunctionResultContent) + assert response.messages[1].contents[0].result == "terminated by middleware" + + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client_base.run_responses) == 1 + + +class SelectiveTerminateMiddleware(FunctionMiddleware): + """Only terminates for terminating_function.""" + + async def process( + self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + if context.function.name == "terminating_function": + # Set result to a simple value - the framework will wrap it in FunctionResultContent + context.result = "terminated by middleware" + context.terminate = True + else: + await next_handler(context) + + +async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: ChatClientProtocol): + """Test that any(terminate_loop=True) exits loop even with multiple function calls.""" + normal_call_count = 0 + terminating_call_count = 0 + + @ai_function(name="normal_function") + def normal_func(arg1: str) -> str: + nonlocal normal_call_count + normal_call_count += 1 + return f"Normal {arg1}" + + @ai_function(name="terminating_function") + def terminating_func(arg1: str) -> str: + nonlocal terminating_call_count + terminating_call_count += 1 + return f"Terminating {arg1}" + + # Queue up two responses: parallel function calls, then final text + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[ + FunctionCallContent(call_id="1", name="normal_function", arguments='{"arg1": "value1"}'), + FunctionCallContent(call_id="2", name="terminating_function", arguments='{"arg1": "value2"}'), + ], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + response = await chat_client_base.get_response( + "hello", + tool_choice="auto", + tools=[normal_func, terminating_func], + middleware=[SelectiveTerminateMiddleware()], + ) + + # normal_function should have executed (middleware calls next_handler) + # terminating_function should NOT have executed (middleware intercepts it) + assert normal_call_count == 1 + assert terminating_call_count == 0 + + # There should be 2 messages: assistant with function calls, tool results + # The loop should NOT have continued to call the LLM again + assert len(response.messages) == 2 + assert response.messages[0].role == Role.ASSISTANT + assert len(response.messages[0].contents) == 2 + assert response.messages[1].role == Role.TOOL + # Both function results should be present + assert len(response.messages[1].contents) == 2 + + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client_base.run_responses) == 1 + + +async def test_terminate_loop_streaming_single_function_call(chat_client_base: ChatClientProtocol): + """Test that terminate_loop=True exits the streaming function calling loop.""" + exec_counter = 0 + + @ai_function(name="test_function") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}" + + # Queue up two streaming responses + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + role="assistant", + ), + ], + [ + ChatResponseUpdate( + contents=[TextContent(text="done")], + role="assistant", + ) + ], + ] + + updates = [] + async for update in chat_client_base.get_streaming_response( + "hello", + tool_choice="auto", + tools=[ai_func], + middleware=[TerminateLoopMiddleware()], + ): + updates.append(update) + + # Function should NOT have been executed - middleware intercepted it + assert exec_counter == 0 + + # Should have function call update and function result update + # The loop should NOT have continued to call the LLM again + assert len(updates) == 2 + + # Verify the second streaming response is still in the queue (wasn't consumed) + assert len(chat_client_base.streaming_responses) == 1 diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 264ff1929b..93643da30f 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -9,7 +9,7 @@ from mcp import types from mcp.client.session import ClientSession from mcp.shared.exceptions import McpError -from pydantic import AnyUrl, ValidationError +from pydantic import AnyUrl, BaseModel, ValidationError from agent_framework import ( ChatMessage, @@ -75,17 +75,21 @@ def test_mcp_call_tool_result_to_ai_contents(): mcp_result = types.CallToolResult( content=[ types.TextContent(type="text", text="Result text"), - types.ImageContent(type="image", data="", mimeType="image/png"), + types.ImageContent(type="image", data="xyz", mimeType="image/png"), + types.ImageContent(type="image", data=b"abc", mimeType="image/webp"), ] ) ai_contents = _mcp_call_tool_result_to_ai_contents(mcp_result) - assert len(ai_contents) == 2 + assert len(ai_contents) == 3 assert isinstance(ai_contents[0], TextContent) assert ai_contents[0].text == "Result text" assert isinstance(ai_contents[1], DataContent) assert ai_contents[1].uri == "" assert ai_contents[1].media_type == "image/png" + assert isinstance(ai_contents[2], DataContent) + assert ai_contents[2].uri == "" + assert ai_contents[2].media_type == "image/webp" def test_mcp_call_tool_result_with_meta_error(): @@ -183,7 +187,7 @@ def test_mcp_call_tool_result_regression_successful_workflow(): mcp_result = types.CallToolResult( content=[ types.TextContent(type="text", text="Success message"), - types.ImageContent(type="image", data="", mimeType="image/jpeg"), + types.ImageContent(type="image", data="abc123", mimeType="image/jpeg"), ] ) @@ -218,7 +222,8 @@ def test_mcp_content_types_to_ai_content_text(): def test_mcp_content_types_to_ai_content_image(): """Test conversion of MCP image content to AI content.""" - mcp_content = types.ImageContent(type="image", data="", mimeType="image/jpeg") + mcp_content = types.ImageContent(type="image", data="abc", mimeType="image/jpeg") + mcp_content = types.ImageContent(type="image", data=b"abc", mimeType="image/jpeg") ai_content = _mcp_type_to_ai_content(mcp_content)[0] assert isinstance(ai_content, DataContent) @@ -229,7 +234,7 @@ def test_mcp_content_types_to_ai_content_image(): def test_mcp_content_types_to_ai_content_audio(): """Test conversion of MCP audio content to AI content.""" - mcp_content = types.AudioContent(type="audio", data="data:audio/wav;base64,def", mimeType="audio/wav") + mcp_content = types.AudioContent(type="audio", data="def", mimeType="audio/wav") ai_content = _mcp_type_to_ai_content(mcp_content)[0] assert isinstance(ai_content, DataContent) @@ -357,122 +362,360 @@ def test_chat_message_to_mcp_types(): assert isinstance(mcp_contents[1], types.ImageContent) -def test_get_input_model_from_mcp_tool(): - """Test creation of input model from MCP tool.""" - tool = types.Tool( - name="test_tool", - description="A test tool", - inputSchema={ - "type": "object", - "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, - "required": ["param1"], - }, - ) - model = _get_input_model_from_mcp_tool(tool) - - # Create an instance to verify the model works - instance = model(param1="test", param2=42) - assert instance.param1 == "test" - assert instance.param2 == 42 - - # Test validation - with pytest.raises(ValidationError): # Missing required param1 - model(param2=42) - - -def test_get_input_model_from_mcp_tool_with_nested_object(): - """Test creation of input model from MCP tool with nested object property.""" - tool = types.Tool( - name="get_customer_detail", - description="Get customer details", - inputSchema={ - "type": "object", - "properties": { - "params": { - "type": "object", - "properties": {"customer_id": {"type": "integer"}}, - "required": ["customer_id"], - } +@pytest.mark.parametrize( + "test_id,input_schema,valid_data,expected_values,invalid_data,validation_check", + [ + # Basic types with required/optional fields + ( + "basic_types", + { + "type": "object", + "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, + "required": ["param1"], }, - "required": ["params"], - }, - ) - model = _get_input_model_from_mcp_tool(tool) - - # Create an instance to verify the model works with nested objects - instance = model(params={"customer_id": 251}) - assert instance.params == {"customer_id": 251} - assert isinstance(instance.params, dict) - - # Verify model_dump produces the correct nested structure - dumped = instance.model_dump() - assert dumped == {"params": {"customer_id": 251}} - - -def test_get_input_model_from_mcp_tool_with_ref_schema(): - """Test creation of input model from MCP tool with $ref schema. - - This simulates a FastMCP tool that uses Pydantic models with $ref in the schema. - The schema should be resolved and nested objects should be preserved. - """ - # This is similar to what FastMCP generates when you have: - # async def get_customer_detail(params: CustomerIdParam) -> CustomerDetail - tool = types.Tool( - name="get_customer_detail", - description="Get customer details", - inputSchema={ - "type": "object", - "properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}}, - "required": ["params"], - "$defs": { - "CustomerIdParam": { - "type": "object", - "properties": {"customer_id": {"type": "integer"}}, - "required": ["customer_id"], + {"param1": "test", "param2": 42}, + {"param1": "test", "param2": 42}, + {"param2": 42}, # Missing required param1 + None, + ), + # Nested object + ( + "nested_object", + { + "type": "object", + "properties": { + "params": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, + "required": ["params"], + }, + {"params": {"customer_id": 251}}, + {"params.customer_id": 251}, + {"params": {}}, # Missing required customer_id + lambda instance: isinstance(instance.params, BaseModel), + ), + # $ref resolution + ( + "ref_schema", + { + "type": "object", + "properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}}, + "required": ["params"], + "$defs": { + "CustomerIdParam": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, + }, + {"params": {"customer_id": 251}}, + {"params.customer_id": 251}, + {"params": {}}, # Missing required customer_id + lambda instance: isinstance(instance.params, BaseModel), + ), + # Array of strings (typed) + ( + "array_of_strings", + { + "type": "object", + "properties": { + "tags": { + "type": "array", + "description": "List of tags", + "items": {"type": "string"}, + } + }, + "required": ["tags"], + }, + {"tags": ["tag1", "tag2", "tag3"]}, + {"tags": ["tag1", "tag2", "tag3"]}, + None, # No validation error test for this case + None, + ), + # Array of integers (typed) + ( + "array_of_integers", + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "description": "List of integers", + "items": {"type": "integer"}, + } + }, + "required": ["numbers"], + }, + {"numbers": [1, 2, 3]}, + {"numbers": [1, 2, 3]}, + None, + None, + ), + # Array of objects (complex nested) + ( + "array_of_objects", + { + "type": "object", + "properties": { + "users": { + "type": "array", + "description": "List of users", + "items": { + "type": "object", + "properties": { + "id": {"type": "integer", "description": "User ID"}, + "name": {"type": "string", "description": "User name"}, + }, + "required": ["id", "name"], + }, + } + }, + "required": ["users"], + }, + {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, + {"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"}, + {"users": [{"id": 1}]}, # Missing required 'name' + lambda instance: all(isinstance(user, BaseModel) for user in instance.users), + ), + # Deeply nested objects (3+ levels) + ( + "deeply_nested", + { + "type": "object", + "properties": { + "query": { + "type": "object", + "properties": { + "filters": { + "type": "object", + "properties": { + "date_range": { + "type": "object", + "properties": { + "start": {"type": "string"}, + "end": {"type": "string"}, + }, + "required": ["start", "end"], + }, + "categories": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["date_range"], + } + }, + "required": ["filters"], + } + }, + "required": ["query"], + }, + { + "query": { + "filters": { + "date_range": {"start": "2024-01-01", "end": "2024-12-31"}, + "categories": ["tech", "science"], + } } }, - }, - ) - model = _get_input_model_from_mcp_tool(tool) - - # Create an instance to verify the model works with $ref schemas - instance = model(params={"customer_id": 251}) - assert instance.params == {"customer_id": 251} - assert isinstance(instance.params, dict) - - # Verify model_dump produces the correct nested structure - dumped = instance.model_dump() - assert dumped == {"params": {"customer_id": 251}} - - -def test_get_input_model_from_mcp_tool_with_simple_array(): - """Test array with simple items schema (items schema should be preserved in json_schema_extra).""" - tool = types.Tool( - name="simple_array_tool", - description="Tool with simple array", - inputSchema={ - "type": "object", - "properties": { - "tags": { - "type": "array", - "description": "List of tags", - "items": {"type": "string"}, # Simple string array + { + "query.filters.date_range.start": "2024-01-01", + "query.filters.date_range.end": "2024-12-31", + "query.filters.categories": ["tech", "science"], + }, + {"query": {"filters": {"date_range": {}}}}, # Missing required start and end + None, + ), + # Complex $ref with nested structure + ( + "ref_nested_structure", + { + "type": "object", + "properties": {"order": {"$ref": "#/$defs/OrderParams"}}, + "required": ["order"], + "$defs": { + "OrderParams": { + "type": "object", + "properties": { + "customer": {"$ref": "#/$defs/Customer"}, + "items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}}, + }, + "required": ["customer", "items"], + }, + "Customer": { + "type": "object", + "properties": {"id": {"type": "integer"}, "email": {"type": "string"}}, + "required": ["id", "email"], + }, + "OrderItem": { + "type": "object", + "properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}}, + "required": ["product_id", "quantity"], + }, + }, + }, + { + "order": { + "customer": {"id": 123, "email": "test@example.com"}, + "items": [{"product_id": "prod1", "quantity": 2}], } }, - "required": ["tags"], - }, - ) + { + "order.customer.id": 123, + "order.customer.email": "test@example.com", + "order.items[0].product_id": "prod1", + "order.items[0].quantity": 2, + }, + {"order": {"customer": {"id": 123}, "items": []}}, # Missing email + lambda instance: isinstance(instance.order.customer, BaseModel), + ), + # Mixed types (primitives, arrays, nested objects) + ( + "mixed_types", + { + "type": "object", + "properties": { + "simple_string": {"type": "string"}, + "simple_number": {"type": "integer"}, + "string_array": {"type": "array", "items": {"type": "string"}}, + "nested_config": { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + "options": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["enabled"], + }, + }, + "required": ["simple_string", "nested_config"], + }, + { + "simple_string": "test", + "simple_number": 42, + "string_array": ["a", "b"], + "nested_config": {"enabled": True, "options": ["opt1", "opt2"]}, + }, + { + "simple_string": "test", + "simple_number": 42, + "string_array": ["a", "b"], + "nested_config.enabled": True, + "nested_config.options": ["opt1", "opt2"], + }, + None, + None, + ), + # Empty schema (no properties) + ( + "empty_schema", + {"type": "object", "properties": {}}, + {}, + {}, + None, + None, + ), + # All primitive types + ( + "all_primitives", + { + "type": "object", + "properties": { + "string_field": {"type": "string"}, + "integer_field": {"type": "integer"}, + "number_field": {"type": "number"}, + "boolean_field": {"type": "boolean"}, + }, + }, + {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, + {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, + None, + None, + ), + # Edge case: unresolvable $ref (fallback to dict) + ( + "unresolvable_ref", + { + "type": "object", + "properties": {"data": {"$ref": "#/$defs/NonExistent"}}, + "$defs": {}, + }, + {"data": {"key": "value"}}, + {"data": {"key": "value"}}, + None, + None, + ), + # Edge case: array without items schema (fallback to bare list) + ( + "array_no_items", + { + "type": "object", + "properties": {"items": {"type": "array"}}, + }, + {"items": [1, "two", 3.0]}, + {"items": [1, "two", 3.0]}, + None, + None, + ), + # Edge case: object without properties (fallback to dict) + ( + "object_no_properties", + { + "type": "object", + "properties": {"config": {"type": "object"}}, + }, + {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, + {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, + None, + None, + ), + ], +) +def test_get_input_model_from_mcp_tool_parametrized( + test_id, input_schema, valid_data, expected_values, invalid_data, validation_check +): + """Parametrized test for JSON schema to Pydantic model conversion. + + This test covers various edge cases including: + - Basic types with required/optional fields + - Nested objects + - $ref resolution + - Typed arrays (strings, integers, objects) + - Deeply nested structures + - Complex $ref with nested structures + - Mixed types + + To add a new test case, add a tuple to the parametrize decorator with: + - test_id: A descriptive name for the test case + - input_schema: The JSON schema (inputSchema dict) + - valid_data: Valid data to instantiate the model + - expected_values: Dict of expected values (supports dot notation for nested access) + - invalid_data: Invalid data to test validation errors (None to skip) + - validation_check: Optional callable to perform additional validation checks + """ + tool = types.Tool(name="test_tool", description="A test tool", inputSchema=input_schema) model = _get_input_model_from_mcp_tool(tool) - # Create an instance - instance = model(tags=["tag1", "tag2", "tag3"]) - assert instance.tags == ["tag1", "tag2", "tag3"] - - # Verify JSON schema still preserves items for simple types - json_schema = model.model_json_schema() - tags_property = json_schema["properties"]["tags"] - assert "items" in tags_property - assert tags_property["items"]["type"] == "string" + # Test valid data + instance = model(**valid_data) + + # Check expected values + for field_path, expected_value in expected_values.items(): + # Support dot notation and array indexing for nested access + current = instance + parts = field_path.replace("]", "").replace("[", ".").split(".") + for part in parts: + current = current[int(part)] if part.isdigit() else getattr(current, part) + assert current == expected_value, f"Field {field_path} = {current}, expected {expected_value}" + + # Run additional validation checks if provided + if validation_check: + assert validation_check(instance), f"Validation check failed for {test_id}" + + # Test invalid data if provided + if invalid_data is not None: + with pytest.raises(ValidationError): + model(**invalid_data) def test_get_input_model_from_mcp_prompt(): diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 7b280da123..6cb41f674b 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -193,7 +193,8 @@ async def process( # Create a message to start the conversation messages = [ChatMessage(role=Role.USER, text="test message")] - # Set up chat client to return a function call + # Set up chat client to return a function call, then a final response + # If terminate works correctly, only the first response should be consumed chat_client.responses = [ ChatResponse( messages=[ @@ -204,7 +205,8 @@ async def process( ], ) ] - ) + ), + ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]), ] # Create the test function with the expected signature @@ -222,7 +224,11 @@ def test_function(text: str) -> str: # Verify that function was not called and only middleware executed assert execution_order == ["middleware_before", "middleware_after"] assert "function_called" not in execution_order - assert execution_order == ["middleware_before", "middleware_after"] + + # Verify the chat client was only called once (no extra LLM call after termination) + assert chat_client.call_count == 1 + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client.responses) == 1 async def test_function_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None: """Test that function middleware can terminate execution after calling next().""" @@ -242,7 +248,8 @@ async def process( # Create a message to start the conversation messages = [ChatMessage(role=Role.USER, text="test message")] - # Set up chat client to return a function call + # Set up chat client to return a function call, then a final response + # If terminate works correctly, only the first response should be consumed chat_client.responses = [ ChatResponse( messages=[ @@ -253,7 +260,8 @@ async def process( ], ) ] - ) + ), + ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]), ] # Create the test function with the expected signature @@ -273,6 +281,11 @@ def test_function(text: str) -> str: assert "function_called" in execution_order assert execution_order == ["middleware_before", "function_called", "middleware_after"] + # Verify the chat client was only called once (no extra LLM call after termination) + assert chat_client.call_count == 1 + # Verify the second response is still in the queue (wasn't consumed) + assert len(chat_client.responses) == 1 + async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None: """Test function-based agent middleware with ChatAgent.""" execution_order: list[str] = [] diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index abdc5184be..8528295406 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -33,8 +33,8 @@ ChatMessageListTimestampFilter, OtelAttr, get_function_span, - use_agent_observability, - use_observability, + use_agent_instrumentation, + use_instrumentation, ) # region Test constants @@ -157,7 +157,7 @@ def test_start_span_with_tool_call_id(span_exporter: InMemorySpanExporter): assert span.attributes[OtelAttr.TOOL_TYPE] == "function" -# region Test use_observability decorator +# region Test use_instrumentation decorator def test_decorator_with_valid_class(): @@ -175,7 +175,7 @@ async def gen(): return gen() # Apply the decorator - decorated_class = use_observability(MockChatClient) + decorated_class = use_instrumentation(MockChatClient) assert hasattr(decorated_class, OPEN_TELEMETRY_CHAT_CLIENT_MARKER) @@ -187,7 +187,7 @@ class MockChatClient: # Apply the decorator - should not raise an error with pytest.raises(ChatClientInitializationError): - use_observability(MockChatClient) + use_instrumentation(MockChatClient) def test_decorator_with_partial_methods(): @@ -200,7 +200,7 @@ async def get_response(self, messages, **kwargs): return Mock() with pytest.raises(ChatClientInitializationError): - use_observability(MockChatClient) + use_instrumentation(MockChatClient) # region Test telemetry decorator with mock client @@ -235,7 +235,7 @@ async def _inner_get_streaming_response( @pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True) async def test_chat_client_observability(mock_chat_client, span_exporter: InMemorySpanExporter, enable_sensitive_data): """Test that when diagnostics are enabled, telemetry is applied.""" - client = use_observability(mock_chat_client)() + client = use_instrumentation(mock_chat_client)() messages = [ChatMessage(role=Role.USER, text="Test message")] span_exporter.clear() @@ -258,8 +258,8 @@ async def test_chat_client_observability(mock_chat_client, span_exporter: InMemo async def test_chat_client_streaming_observability( mock_chat_client, span_exporter: InMemorySpanExporter, enable_sensitive_data ): - """Test streaming telemetry through the use_observability decorator.""" - client = use_observability(mock_chat_client)() + """Test streaming telemetry through the use_instrumentation decorator.""" + client = use_instrumentation(mock_chat_client)() messages = [ChatMessage(role=Role.USER, text="Test")] span_exporter.clear() # Collect all yielded updates @@ -282,7 +282,7 @@ async def test_chat_client_streaming_observability( async def test_chat_client_without_model_id_observability(mock_chat_client, span_exporter: InMemorySpanExporter): """Test telemetry shouldn't fail when the model_id is not provided for unknown reason.""" - client = use_observability(mock_chat_client)() + client = use_instrumentation(mock_chat_client)() messages = [ChatMessage(role=Role.USER, text="Test")] span_exporter.clear() response = await client.get_response(messages=messages) @@ -301,7 +301,7 @@ async def test_chat_client_streaming_without_model_id_observability( mock_chat_client, span_exporter: InMemorySpanExporter ): """Test streaming telemetry shouldn't fail when the model_id is not provided for unknown reason.""" - client = use_observability(mock_chat_client)() + client = use_instrumentation(mock_chat_client)() messages = [ChatMessage(role=Role.USER, text="Test")] span_exporter.clear() # Collect all yielded updates @@ -329,7 +329,7 @@ def test_prepend_user_agent_with_none_value(): assert AGENT_FRAMEWORK_USER_AGENT in str(result["User-Agent"]) -# region Test use_agent_observability decorator +# region Test use_agent_instrumentation decorator def test_agent_decorator_with_valid_class(): @@ -337,7 +337,7 @@ def test_agent_decorator_with_valid_class(): # Create a mock class with the required methods class MockChatClientAgent: - AGENT_SYSTEM_NAME = "test_agent_system" + AGENT_PROVIDER_NAME = "test_agent_system" def __init__(self): self.id = "test_agent_id" @@ -358,7 +358,7 @@ def get_new_thread(self) -> AgentThread: return AgentThread() # Apply the decorator - decorated_class = use_agent_observability(MockChatClientAgent) + decorated_class = use_agent_instrumentation(MockChatClientAgent) assert hasattr(decorated_class, OPEN_TELEMETRY_AGENT_MARKER) @@ -367,19 +367,19 @@ def test_agent_decorator_with_missing_methods(): """Test that agent decorator handles classes missing required methods gracefully.""" class MockAgent: - AGENT_SYSTEM_NAME = "test_agent_system" + AGENT_PROVIDER_NAME = "test_agent_system" # Apply the decorator - should not raise an error with pytest.raises(AgentInitializationError): - use_agent_observability(MockAgent) + use_agent_instrumentation(MockAgent) def test_agent_decorator_with_partial_methods(): """Test agent decorator when only one method is present.""" - from agent_framework.observability import use_agent_observability + from agent_framework.observability import use_agent_instrumentation class MockAgent: - AGENT_SYSTEM_NAME = "test_agent_system" + AGENT_PROVIDER_NAME = "test_agent_system" def __init__(self): self.id = "test_agent_id" @@ -390,7 +390,7 @@ async def run(self, messages=None, *, thread=None, **kwargs): return Mock() with pytest.raises(AgentInitializationError): - use_agent_observability(MockAgent) + use_agent_instrumentation(MockAgent) # region Test agent telemetry decorator with mock agent @@ -401,7 +401,7 @@ def mock_chat_agent(): """Create a mock chat client agent for testing.""" class MockChatClientAgent: - AGENT_SYSTEM_NAME = "test_agent_system" + AGENT_PROVIDER_NAME = "test_agent_system" def __init__(self): self.id = "test_agent_id" @@ -433,7 +433,7 @@ async def test_agent_instrumentation_enabled( ): """Test that when agent diagnostics are enabled, telemetry is applied.""" - agent = use_agent_observability(mock_chat_agent)() + agent = use_agent_instrumentation(mock_chat_agent)() span_exporter.clear() response = await agent.run("Test message") @@ -457,8 +457,8 @@ async def test_agent_instrumentation_enabled( async def test_agent_streaming_response_with_diagnostics_enabled_via_decorator( mock_chat_agent: AgentProtocol, span_exporter: InMemorySpanExporter, enable_sensitive_data ): - """Test agent streaming telemetry through the use_agent_observability decorator.""" - agent = use_agent_observability(mock_chat_agent)() + """Test agent streaming telemetry through the use_agent_instrumentation decorator.""" + agent = use_agent_instrumentation(mock_chat_agent)() span_exporter.clear() updates = [] async for update in agent.run_stream("Test message"): @@ -522,3 +522,393 @@ async def failing_function(param: str) -> str: exception_message = exception_event.attributes["exception.message"] assert isinstance(exception_message, str) assert "Function execution failed" in exception_message + + +# region Test OTEL environment variable parsing + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_grpc_endpoint(monkeypatch): + """Test _get_exporters_from_env with OTEL_EXPORTER_OTLP_ENDPOINT (gRPC).""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + + exporters = _get_exporters_from_env() + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_http_endpoint(monkeypatch): + """Test _get_exporters_from_env with OTEL_EXPORTER_OTLP_ENDPOINT (HTTP).""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http") + + exporters = _get_exporters_from_env() + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_individual_endpoints(monkeypatch): + """Test _get_exporters_from_env with individual signal endpoints.""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", "http://localhost:4318") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", "http://localhost:4319") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + + exporters = _get_exporters_from_env() + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_headers(monkeypatch): + """Test _get_exporters_from_env with OTEL_EXPORTER_OTLP_HEADERS.""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_HEADERS", "key1=value1,key2=value2") + + exporters = _get_exporters_from_env() + + # Should return 3 exporters with headers + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_with_signal_specific_headers(monkeypatch): + """Test _get_exporters_from_env with signal-specific headers.""" + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_HEADERS", "trace-key=trace-value") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + + exporters = _get_exporters_from_env() + + # Should have at least the traces exporter + assert len(exporters) >= 1 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_without_env_vars(monkeypatch): + """Test _get_exporters_from_env returns empty list when no env vars set.""" + from agent_framework.observability import _get_exporters_from_env + + # Clear all OTEL env vars + for key in [ + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + ]: + monkeypatch.delenv(key, raising=False) + + exporters = _get_exporters_from_env() + + # Should return empty list + assert len(exporters) == 0 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_get_exporters_from_env_missing_grpc_dependency(monkeypatch): + """Test _get_exporters_from_env raises ImportError when gRPC exporters not installed.""" + + from agent_framework.observability import _get_exporters_from_env + + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + + # Mock the import to raise ImportError + original_import = __builtins__.__import__ + + def mock_import(name, *args, **kwargs): + if "opentelemetry.exporter.otlp.proto.grpc" in name: + raise ImportError("No module named 'opentelemetry.exporter.otlp.proto.grpc'") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(__builtins__, "__import__", mock_import) + + with pytest.raises(ImportError, match="opentelemetry-exporter-otlp-proto-grpc"): + _get_exporters_from_env() + + +# region Test create_resource + + +def test_create_resource_from_env(monkeypatch): + """Test create_resource reads OTEL environment variables.""" + from agent_framework.observability import create_resource + + monkeypatch.setenv("OTEL_SERVICE_NAME", "test-service") + monkeypatch.setenv("OTEL_SERVICE_VERSION", "1.0.0") + monkeypatch.setenv("OTEL_RESOURCE_ATTRIBUTES", "deployment.environment=production,host.name=server1") + + resource = create_resource() + + assert resource.attributes["service.name"] == "test-service" + assert resource.attributes["service.version"] == "1.0.0" + assert resource.attributes["deployment.environment"] == "production" + assert resource.attributes["host.name"] == "server1" + + +def test_create_resource_with_parameters_override_env(monkeypatch): + """Test create_resource parameters override environment variables.""" + from agent_framework.observability import create_resource + + monkeypatch.setenv("OTEL_SERVICE_NAME", "env-service") + monkeypatch.setenv("OTEL_SERVICE_VERSION", "0.1.0") + + resource = create_resource(service_name="param-service", service_version="2.0.0") + + # Parameters should override env vars + assert resource.attributes["service.name"] == "param-service" + assert resource.attributes["service.version"] == "2.0.0" + + +def test_create_resource_with_custom_attributes(monkeypatch): + """Test create_resource accepts custom attributes.""" + from agent_framework.observability import create_resource + + resource = create_resource(custom_attr="custom_value", another_attr=123) + + assert resource.attributes["custom_attr"] == "custom_value" + assert resource.attributes["another_attr"] == 123 + + +# region Test _create_otlp_exporters + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_grpc_with_single_endpoint(): + """Test _create_otlp_exporters creates gRPC exporters with single endpoint.""" + from agent_framework.observability import _create_otlp_exporters + + exporters = _create_otlp_exporters(endpoint="http://localhost:4317", protocol="grpc") + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_http_with_single_endpoint(): + """Test _create_otlp_exporters creates HTTP exporters with single endpoint.""" + from agent_framework.observability import _create_otlp_exporters + + exporters = _create_otlp_exporters(endpoint="http://localhost:4318", protocol="http") + + # Should return 3 exporters (trace, metrics, logs) + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_with_individual_endpoints(): + """Test _create_otlp_exporters with individual signal endpoints.""" + from agent_framework.observability import _create_otlp_exporters + + exporters = _create_otlp_exporters( + protocol="grpc", + traces_endpoint="http://localhost:4317", + metrics_endpoint="http://localhost:4318", + logs_endpoint="http://localhost:4319", + ) + + # Should return 3 exporters + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_with_headers(): + """Test _create_otlp_exporters with headers.""" + from agent_framework.observability import _create_otlp_exporters + + exporters = _create_otlp_exporters( + endpoint="http://localhost:4317", protocol="grpc", headers={"Authorization": "Bearer token"} + ) + + # Should return 3 exporters with headers + assert len(exporters) == 3 + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_create_otlp_exporters_grpc_missing_dependency(): + """Test _create_otlp_exporters raises ImportError when gRPC exporters not installed.""" + import sys + from unittest.mock import patch + + from agent_framework.observability import _create_otlp_exporters + + # Mock the import to raise ImportError + with ( + patch.dict(sys.modules, {"opentelemetry.exporter.otlp.proto.grpc.trace_exporter": None}), + pytest.raises(ImportError, match="opentelemetry-exporter-otlp-proto-grpc"), + ): + _create_otlp_exporters(endpoint="http://localhost:4317", protocol="grpc") + + +# region Test configure_otel_providers with views + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_configure_otel_providers_with_views(monkeypatch): + """Test configure_otel_providers accepts views parameter.""" + from opentelemetry.sdk.metrics import View + from opentelemetry.sdk.metrics.view import DropAggregation + + from agent_framework.observability import configure_otel_providers + + # Clear all OTEL env vars + for key in [ + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + ]: + monkeypatch.delenv(key, raising=False) + + # Create a view that drops all metrics + views = [View(instrument_name="*", aggregation=DropAggregation())] + + # Should not raise an error + configure_otel_providers(views=views) + + +@pytest.mark.skipif( + True, + reason="Skipping OTLP exporter tests - optional dependency not installed by default", +) +def test_configure_otel_providers_without_views(monkeypatch): + """Test configure_otel_providers works without views parameter.""" + from agent_framework.observability import configure_otel_providers + + # Clear all OTEL env vars + for key in [ + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + ]: + monkeypatch.delenv(key, raising=False) + + # Should not raise an error with default empty views + configure_otel_providers() + + +# region Test console exporters opt-in + + +def test_console_exporters_opt_in_false(monkeypatch): + """Test console exporters are not added when ENABLE_CONSOLE_EXPORTERS is false.""" + from agent_framework.observability import ObservabilitySettings + + monkeypatch.setenv("ENABLE_CONSOLE_EXPORTERS", "false") + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + + settings = ObservabilitySettings(env_file_path="test.env") + assert settings.enable_console_exporters is False + + +def test_console_exporters_opt_in_true(monkeypatch): + """Test console exporters are added when ENABLE_CONSOLE_EXPORTERS is true.""" + from agent_framework.observability import ObservabilitySettings + + monkeypatch.setenv("ENABLE_CONSOLE_EXPORTERS", "true") + + settings = ObservabilitySettings(env_file_path="test.env") + assert settings.enable_console_exporters is True + + +def test_console_exporters_default_false(monkeypatch): + """Test console exporters default to False when not set.""" + from agent_framework.observability import ObservabilitySettings + + monkeypatch.delenv("ENABLE_CONSOLE_EXPORTERS", raising=False) + + settings = ObservabilitySettings(env_file_path="test.env") + assert settings.enable_console_exporters is False + + +# region Test _parse_headers helper + + +def test_parse_headers_valid(): + """Test _parse_headers with valid header string.""" + from agent_framework.observability import _parse_headers + + headers = _parse_headers("key1=value1,key2=value2") + assert headers == {"key1": "value1", "key2": "value2"} + + +def test_parse_headers_with_spaces(): + """Test _parse_headers handles spaces around keys and values.""" + from agent_framework.observability import _parse_headers + + headers = _parse_headers("key1 = value1 , key2 = value2 ") + assert headers == {"key1": "value1", "key2": "value2"} + + +def test_parse_headers_empty_string(): + """Test _parse_headers with empty string.""" + from agent_framework.observability import _parse_headers + + headers = _parse_headers("") + assert headers == {} + + +def test_parse_headers_invalid_format(): + """Test _parse_headers ignores invalid pairs.""" + from agent_framework.observability import _parse_headers + + headers = _parse_headers("key1=value1,invalid,key2=value2") + # Should only include valid pairs + assert headers == {"key1": "value1", "key2": "value2"} diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index c1cc0f119b..88c34dc3e8 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any +from typing import Annotated, Any, Literal from unittest.mock import Mock import pytest @@ -14,7 +14,7 @@ ToolProtocol, ai_function, ) -from agent_framework._tools import _parse_inputs +from agent_framework._tools import _parse_annotation, _parse_inputs from agent_framework.exceptions import ToolException from agent_framework.observability import OtelAttr @@ -128,6 +128,95 @@ def test_tool(self, x: int, y: int) -> int: assert test_tool(1, 2) == 3 +def test_ai_function_with_literal_type_parameter(): + """Test ai_function decorator with Literal type parameter (issue #2891).""" + + @ai_function + def search_flows(category: Literal["Data", "Security", "Network"], issue: str) -> str: + """Search flows by category.""" + return f"{category}: {issue}" + + assert isinstance(search_flows, AIFunction) + schema = search_flows.parameters() + assert schema == { + "properties": { + "category": {"enum": ["Data", "Security", "Network"], "title": "Category", "type": "string"}, + "issue": {"title": "Issue", "type": "string"}, + }, + "required": ["category", "issue"], + "title": "search_flows_input", + "type": "object", + } + # Verify invocation works + assert search_flows("Data", "test issue") == "Data: test issue" + + +def test_ai_function_with_literal_type_in_class_method(): + """Test ai_function decorator with Literal type parameter in a class method (issue #2891).""" + + class MyTools: + @ai_function + def search_flows(self, category: Literal["Data", "Security", "Network"], issue: str) -> str: + """Search flows by category.""" + return f"{category}: {issue}" + + tools = MyTools() + search_tool = tools.search_flows + assert isinstance(search_tool, AIFunction) + schema = search_tool.parameters() + assert schema == { + "properties": { + "category": {"enum": ["Data", "Security", "Network"], "title": "Category", "type": "string"}, + "issue": {"title": "Issue", "type": "string"}, + }, + "required": ["category", "issue"], + "title": "search_flows_input", + "type": "object", + } + # Verify invocation works + assert search_tool("Security", "test issue") == "Security: test issue" + + +def test_ai_function_with_literal_int_type(): + """Test ai_function decorator with Literal int type parameter.""" + + @ai_function + def set_priority(priority: Literal[1, 2, 3], task: str) -> str: + """Set priority for a task.""" + return f"Priority {priority}: {task}" + + assert isinstance(set_priority, AIFunction) + schema = set_priority.parameters() + assert schema == { + "properties": { + "priority": {"enum": [1, 2, 3], "title": "Priority", "type": "integer"}, + "task": {"title": "Task", "type": "string"}, + }, + "required": ["priority", "task"], + "title": "set_priority_input", + "type": "object", + } + assert set_priority(1, "important task") == "Priority 1: important task" + + +def test_ai_function_with_literal_and_annotated(): + """Test ai_function decorator with Literal type combined with Annotated for description.""" + + @ai_function + def categorize( + category: Annotated[Literal["A", "B", "C"], "The category to assign"], + name: str, + ) -> str: + """Categorize an item.""" + return f"{category}: {name}" + + assert isinstance(categorize, AIFunction) + schema = categorize.parameters() + # Literal type inside Annotated should preserve enum values + assert schema["properties"]["category"]["enum"] == ["A", "B", "C"] + assert categorize("A", "test") == "A: test" + + async def test_ai_function_decorator_shared_state(): """Test that decorated methods maintain shared state across multiple calls and tool usage.""" @@ -1334,3 +1423,104 @@ async def mock_get_streaming_response(self, messages, **kwargs): assert updates[2].role == Role.ASSISTANT assert len(updates[2].contents) == 2 assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents) + + +async def test_ai_function_with_kwargs_injection(): + """Test that ai_function correctly handles kwargs injection and hides them from schema.""" + + @ai_function + def tool_with_kwargs(x: int, **kwargs: Any) -> str: + """A tool that accepts kwargs.""" + user_id = kwargs.get("user_id", "unknown") + return f"x={x}, user={user_id}" + + # Verify schema does not include kwargs + assert tool_with_kwargs.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "tool_with_kwargs_input", + "type": "object", + } + + # Verify direct invocation works + assert tool_with_kwargs(1, user_id="user1") == "x=1, user=user1" + + # Verify invoke works with injected args + result = await tool_with_kwargs.invoke( + arguments=tool_with_kwargs.input_model(x=5), + user_id="user2", + ) + assert result == "x=5, user=user2" + + # Verify invoke works without injected args (uses default) + result_default = await tool_with_kwargs.invoke( + arguments=tool_with_kwargs.input_model(x=10), + ) + assert result_default == "x=10, user=unknown" + + +# region _parse_annotation tests + + +def test_parse_annotation_with_literal_type(): + """Test that _parse_annotation returns Literal types unchanged (issue #2891).""" + from typing import get_args, get_origin + + # Literal with string values + literal_annotation = Literal["Data", "Security", "Network"] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == ("Data", "Security", "Network") + + +def test_parse_annotation_with_literal_int_type(): + """Test that _parse_annotation returns Literal int types unchanged.""" + from typing import get_args, get_origin + + literal_annotation = Literal[1, 2, 3] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == (1, 2, 3) + + +def test_parse_annotation_with_literal_bool_type(): + """Test that _parse_annotation returns Literal bool types unchanged.""" + from typing import get_args, get_origin + + literal_annotation = Literal[True, False] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == (True, False) + + +def test_parse_annotation_with_simple_types(): + """Test that _parse_annotation returns simple types unchanged.""" + assert _parse_annotation(str) is str + assert _parse_annotation(int) is int + assert _parse_annotation(float) is float + assert _parse_annotation(bool) is bool + + +def test_parse_annotation_with_annotated_and_literal(): + """Test that Annotated[Literal[...], description] works correctly.""" + from typing import get_args, get_origin + + # When Literal is inside Annotated, it should still be preserved + annotated_literal = Annotated[Literal["A", "B", "C"], "The category"] + result = _parse_annotation(annotated_literal) + + # The Annotated type should be preserved + origin = get_origin(result) + assert origin is Annotated + + args = get_args(result) + # First arg is the Literal type + literal_type = args[0] + assert get_origin(literal_type) is Literal + assert get_args(literal_type) == ("A", "B", "C") + + +# endregion diff --git a/python/packages/core/tests/openai/test_openai_responses_client.py b/python/packages/core/tests/openai/test_openai_responses_client.py index f2f9004d55..3863f4701a 100644 --- a/python/packages/core/tests/openai/test_openai_responses_client.py +++ b/python/packages/core/tests/openai/test_openai_responses_client.py @@ -832,7 +832,7 @@ def test_create_streaming_response_content_with_mcp_approval_request() -> None: assert fa.function_call.name == "do_stream_action" -@pytest.mark.parametrize("enable_otel", [False], indirect=True) +@pytest.mark.parametrize("enable_instrumentation", [False], indirect=True) @pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True) async def test_end_to_end_mcp_approval_flow(span_exporter) -> None: """End-to-end mocked test: @@ -993,6 +993,110 @@ def test_streaming_response_basic_structure() -> None: assert response.raw_representation is mock_event +def test_streaming_annotation_added_with_file_path() -> None: + """Test streaming annotation added event with file_path type extracts HostedFileContent.""" + client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + + mock_event = MagicMock() + mock_event.type = "response.output_text.annotation.added" + mock_event.annotation_index = 0 + mock_event.annotation = { + "type": "file_path", + "file_id": "file-abc123", + "index": 42, + } + + response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + + assert len(response.contents) == 1 + content = response.contents[0] + assert isinstance(content, HostedFileContent) + assert content.file_id == "file-abc123" + assert content.additional_properties is not None + assert content.additional_properties.get("annotation_index") == 0 + assert content.additional_properties.get("index") == 42 + + +def test_streaming_annotation_added_with_file_citation() -> None: + """Test streaming annotation added event with file_citation type extracts HostedFileContent.""" + client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + + mock_event = MagicMock() + mock_event.type = "response.output_text.annotation.added" + mock_event.annotation_index = 1 + mock_event.annotation = { + "type": "file_citation", + "file_id": "file-xyz789", + "filename": "sample.txt", + "index": 15, + } + + response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + + assert len(response.contents) == 1 + content = response.contents[0] + assert isinstance(content, HostedFileContent) + assert content.file_id == "file-xyz789" + assert content.additional_properties is not None + assert content.additional_properties.get("filename") == "sample.txt" + assert content.additional_properties.get("index") == 15 + + +def test_streaming_annotation_added_with_container_file_citation() -> None: + """Test streaming annotation added event with container_file_citation type.""" + client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + + mock_event = MagicMock() + mock_event.type = "response.output_text.annotation.added" + mock_event.annotation_index = 2 + mock_event.annotation = { + "type": "container_file_citation", + "file_id": "file-container123", + "container_id": "container-456", + "filename": "data.csv", + "start_index": 10, + "end_index": 50, + } + + response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + + assert len(response.contents) == 1 + content = response.contents[0] + assert isinstance(content, HostedFileContent) + assert content.file_id == "file-container123" + assert content.additional_properties is not None + assert content.additional_properties.get("container_id") == "container-456" + assert content.additional_properties.get("filename") == "data.csv" + assert content.additional_properties.get("start_index") == 10 + assert content.additional_properties.get("end_index") == 50 + + +def test_streaming_annotation_added_with_unknown_type() -> None: + """Test streaming annotation added event with unknown type is ignored.""" + client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + + mock_event = MagicMock() + mock_event.type = "response.output_text.annotation.added" + mock_event.annotation_index = 0 + mock_event.annotation = { + "type": "url_citation", + "url": "https://example.com", + } + + response = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + + # url_citation should not produce HostedFileContent + assert len(response.contents) == 0 + + def test_service_response_exception_includes_original_error_details() -> None: """Test that ServiceResponseException messages include original error details in the new format.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") diff --git a/python/packages/core/tests/test_observability_datetime.py b/python/packages/core/tests/test_observability_datetime.py index 05efdc1a5e..6ad3d77e1a 100644 --- a/python/packages/core/tests/test_observability_datetime.py +++ b/python/packages/core/tests/test_observability_datetime.py @@ -22,5 +22,5 @@ def test_datetime_in_tool_results() -> None: result = _to_otel_part(content) parsed = json.loads(result["response"]) - # Datetime should be converted to string - assert isinstance(parsed["timestamp"], str) + # Datetime should be converted to string in the result field + assert isinstance(parsed["result"]["timestamp"], str) diff --git a/python/packages/core/tests/workflow/test_concurrent.py b/python/packages/core/tests/workflow/test_concurrent.py index db70be3f38..57810b8f59 100644 --- a/python/packages/core/tests/workflow/test_concurrent.py +++ b/python/packages/core/tests/workflow/test_concurrent.py @@ -3,6 +3,7 @@ from typing import Any, cast import pytest +from typing_extensions import Never from agent_framework import ( AgentExecutorRequest, @@ -52,6 +53,55 @@ def test_concurrent_builder_rejects_duplicate_executors() -> None: ConcurrentBuilder().participants([a, b]) +def test_concurrent_builder_rejects_duplicate_executors_from_factories() -> None: + """Test that duplicate executor IDs from factories are detected at build time.""" + + def create_dup1() -> Executor: + return _FakeAgentExec("dup", "A") + + def create_dup2() -> Executor: + return _FakeAgentExec("dup", "B") # same executor id + + builder = ConcurrentBuilder().register_participants([create_dup1, create_dup2]) + with pytest.raises(ValueError, match="Duplicate executor ID 'dup' detected in workflow."): + builder.build() + + +def test_concurrent_builder_rejects_mixed_participants_and_factories() -> None: + """Test that mixing .participants() and .register_participants() raises an error.""" + # Case 1: participants first, then register_participants + with pytest.raises(ValueError, match="Cannot mix .participants"): + ( + ConcurrentBuilder() + .participants([_FakeAgentExec("a", "A")]) + .register_participants([lambda: _FakeAgentExec("b", "B")]) + ) + + # Case 2: register_participants first, then participants + with pytest.raises(ValueError, match="Cannot mix .participants"): + ( + ConcurrentBuilder() + .register_participants([lambda: _FakeAgentExec("a", "A")]) + .participants([_FakeAgentExec("b", "B")]) + ) + + +def test_concurrent_builder_rejects_multiple_calls_to_participants() -> None: + """Test that multiple calls to .participants() raises an error.""" + with pytest.raises(ValueError, match=r"participants\(\) has already been called"): + (ConcurrentBuilder().participants([_FakeAgentExec("a", "A")]).participants([_FakeAgentExec("b", "B")])) + + +def test_concurrent_builder_rejects_multiple_calls_to_register_participants() -> None: + """Test that multiple calls to .register_participants() raises an error.""" + with pytest.raises(ValueError, match=r"register_participants\(\) has already been called"): + ( + ConcurrentBuilder() + .register_participants([lambda: _FakeAgentExec("a", "A")]) + .register_participants([lambda: _FakeAgentExec("b", "B")]) + ) + + async def test_concurrent_default_aggregator_emits_single_user_and_assistants() -> None: # Three synthetic agent executors e1 = _FakeAgentExec("agentA", "Alpha") @@ -159,6 +209,138 @@ def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[over assert aggregator.id == "summarize" +async def test_concurrent_with_aggregator_executor_instance() -> None: + """Test with_aggregator using an Executor instance (not factory).""" + + class CustomAggregator(Executor): + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + texts: list[str] = [] + for r in results: + msgs: list[ChatMessage] = r.agent_run_response.messages + texts.append(msgs[-1].text if msgs else "") + await ctx.yield_output(" & ".join(sorted(texts))) + + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + aggregator_instance = CustomAggregator(id="instance_aggregator") + wf = ConcurrentBuilder().participants([e1, e2]).with_aggregator(aggregator_instance).build() + + completed = False + output: str | None = None + async for ev in wf.run_stream("prompt: instance test"): + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + completed = True + elif isinstance(ev, WorkflowOutputEvent): + output = cast(str, ev.data) + if completed and output is not None: + break + + assert completed + assert output is not None + assert isinstance(output, str) + assert output == "One & Two" + + +async def test_concurrent_with_aggregator_executor_factory() -> None: + """Test with_aggregator using an Executor factory.""" + + class CustomAggregator(Executor): + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + texts: list[str] = [] + for r in results: + msgs: list[ChatMessage] = r.agent_run_response.messages + texts.append(msgs[-1].text if msgs else "") + await ctx.yield_output(" | ".join(sorted(texts))) + + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + wf = ( + ConcurrentBuilder() + .participants([e1, e2]) + .register_aggregator(lambda: CustomAggregator(id="custom_aggregator")) + .build() + ) + + completed = False + output: str | None = None + async for ev in wf.run_stream("prompt: factory test"): + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + completed = True + elif isinstance(ev, WorkflowOutputEvent): + output = cast(str, ev.data) + if completed and output is not None: + break + + assert completed + assert output is not None + assert isinstance(output, str) + assert output == "One | Two" + + +async def test_concurrent_with_aggregator_executor_factory_with_default_id() -> None: + """Test with_aggregator using an Executor class directly as factory (with default __init__ parameters).""" + + class CustomAggregator(Executor): + def __init__(self, id: str = "default_aggregator") -> None: + super().__init__(id) + + @handler + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + texts: list[str] = [] + for r in results: + msgs: list[ChatMessage] = r.agent_run_response.messages + texts.append(msgs[-1].text if msgs else "") + await ctx.yield_output(" | ".join(sorted(texts))) + + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + wf = ConcurrentBuilder().participants([e1, e2]).register_aggregator(CustomAggregator).build() + + completed = False + output: str | None = None + async for ev in wf.run_stream("prompt: factory test"): + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + completed = True + elif isinstance(ev, WorkflowOutputEvent): + output = cast(str, ev.data) + if completed and output is not None: + break + + assert completed + assert output is not None + assert isinstance(output, str) + assert output == "One | Two" + + +def test_concurrent_builder_rejects_multiple_calls_to_with_aggregator() -> None: + """Test that multiple calls to .with_aggregator() raises an error.""" + + def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[override] + return str(len(results)) + + with pytest.raises(ValueError, match=r"with_aggregator\(\) has already been called"): + (ConcurrentBuilder().with_aggregator(summarize).with_aggregator(summarize)) + + +def test_concurrent_builder_rejects_multiple_calls_to_register_aggregator() -> None: + """Test that multiple calls to .register_aggregator() raises an error.""" + + class CustomAggregator(Executor): + pass + + with pytest.raises(ValueError, match=r"register_aggregator\(\) has already been called"): + ( + ConcurrentBuilder() + .register_aggregator(lambda: CustomAggregator(id="agg1")) + .register_aggregator(lambda: CustomAggregator(id="agg2")) + ) + + async def test_concurrent_checkpoint_resume_round_trip() -> None: storage = InMemoryCheckpointStorage() @@ -278,3 +460,92 @@ async def test_concurrent_checkpoint_runtime_overrides_buildtime() -> None: assert len(runtime_checkpoints) > 0, "Runtime storage should have checkpoints" assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden" + + +def test_concurrent_builder_rejects_empty_participant_factories() -> None: + with pytest.raises(ValueError): + ConcurrentBuilder().register_participants([]) + + +async def test_concurrent_builder_reusable_after_build_with_participants() -> None: + """Test that the builder can be reused to build multiple identical workflows with participants().""" + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + builder = ConcurrentBuilder().participants([e1, e2]) + + builder.build() + + assert builder._participants[0] is e1 # type: ignore + assert builder._participants[1] is e2 # type: ignore + assert builder._participant_factories == [] # type: ignore + + +async def test_concurrent_builder_reusable_after_build_with_factories() -> None: + """Test that the builder can be reused to build multiple workflows with register_participants().""" + call_count = 0 + + def create_agent_executor_a() -> Executor: + nonlocal call_count + call_count += 1 + return _FakeAgentExec("agentA", "One") + + def create_agent_executor_b() -> Executor: + nonlocal call_count + call_count += 1 + return _FakeAgentExec("agentB", "Two") + + builder = ConcurrentBuilder().register_participants([create_agent_executor_a, create_agent_executor_b]) + + # Build the first workflow + wf1 = builder.build() + + assert builder._participants == [] # type: ignore + assert len(builder._participant_factories) == 2 # type: ignore + assert call_count == 2 + + # Build the second workflow + wf2 = builder.build() + assert call_count == 4 + + # Verify that the two workflows have different executor instances + assert wf1.executors["agentA"] is not wf2.executors["agentA"] + assert wf1.executors["agentB"] is not wf2.executors["agentB"] + + +async def test_concurrent_with_register_participants() -> None: + """Test workflow creation using register_participants with factories.""" + + def create_agent1() -> Executor: + return _FakeAgentExec("agentA", "Alpha") + + def create_agent2() -> Executor: + return _FakeAgentExec("agentB", "Beta") + + def create_agent3() -> Executor: + return _FakeAgentExec("agentC", "Gamma") + + wf = ConcurrentBuilder().register_participants([create_agent1, create_agent2, create_agent3]).build() + + completed = False + output: list[ChatMessage] | None = None + async for ev in wf.run_stream("test prompt"): + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + completed = True + elif isinstance(ev, WorkflowOutputEvent): + output = cast(list[ChatMessage], ev.data) + if completed and output is not None: + break + + assert completed + assert output is not None + messages: list[ChatMessage] = output + + # Expect one user message + one assistant message per participant + assert len(messages) == 1 + 3 + assert messages[0].role == Role.USER + assert "test prompt" in messages[0].text + + assistant_texts = {m.text for m in messages[1:]} + assert assistant_texts == {"Alpha", "Beta", "Gamma"} + assert all(m.role == Role.ASSISTANT for m in messages[1:]) diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py index 0ceccfaf15..d0d5092323 100644 --- a/python/packages/core/tests/workflow/test_handoff.py +++ b/python/packages/core/tests/workflow/test_handoff.py @@ -25,7 +25,12 @@ from agent_framework._mcp import MCPTool from agent_framework._workflows import AgentRunEvent from agent_framework._workflows import _handoff as handoff_module # type: ignore -from agent_framework._workflows._handoff import _clone_chat_agent # type: ignore[reportPrivateUsage] +from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from agent_framework._workflows._handoff import ( + _clone_chat_agent, # type: ignore[reportPrivateUsage] + _ConversationWithUserInput, + _UserInputGateway, +) from agent_framework._workflows._workflow_builder import WorkflowBuilder @@ -218,7 +223,7 @@ async def test_handoff_preserves_complex_additional_properties(complex_metadata: workflow = ( HandoffBuilder(participants=[triage, specialist]) - .set_coordinator("triage") + .set_coordinator(triage) .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role == Role.USER) >= 2) .build() ) @@ -281,7 +286,7 @@ async def test_tool_call_handoff_detection_with_text_hint(): triage = _RecordingAgent(name="triage", handoff_to="specialist", text_handoff=True) specialist = _RecordingAgent(name="specialist") - workflow = HandoffBuilder(participants=[triage, specialist]).set_coordinator("triage").build() + workflow = HandoffBuilder(participants=[triage, specialist]).set_coordinator(triage).build() await _drain(workflow.run_stream("Package arrived broken")) @@ -296,7 +301,7 @@ async def test_autonomous_interaction_mode_yields_output_without_user_request(): workflow = ( HandoffBuilder(participants=[triage, specialist]) - .set_coordinator("triage") + .set_coordinator(triage) .with_interaction_mode("autonomous", autonomous_turn_limit=1) .build() ) @@ -428,13 +433,13 @@ def test_build_fails_without_coordinator(): triage = _RecordingAgent(name="triage") specialist = _RecordingAgent(name="specialist") - with pytest.raises(ValueError, match="coordinator must be defined before build"): + with pytest.raises(ValueError, match=r"Must call set_coordinator\(...\) before building the workflow."): HandoffBuilder(participants=[triage, specialist]).build() def test_build_fails_without_participants(): """Verify that build() raises ValueError when no participants are provided.""" - with pytest.raises(ValueError, match="No participants provided"): + with pytest.raises(ValueError, match="No participants or participant_factories have been configured."): HandoffBuilder().build() @@ -605,7 +610,7 @@ async def test_return_to_previous_enabled(): workflow = ( HandoffBuilder(participants=[triage, specialist_a, specialist_b]) - .set_coordinator("triage") + .set_coordinator(triage) .enable_return_to_previous(True) .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3) .build() @@ -638,7 +643,7 @@ def test_handoff_builder_sets_start_executor_once(monkeypatch: pytest.MonkeyPatc workflow = ( HandoffBuilder(participants=[coordinator, specialist]) - .set_coordinator("coordinator") + .set_coordinator(coordinator) .with_termination_condition(lambda conv: len(conv) > 0) .build() ) @@ -698,7 +703,7 @@ async def test_handoff_builder_with_request_info(): # Build workflow with request info enabled workflow = ( HandoffBuilder(participants=[coordinator, specialist]) - .set_coordinator("coordinator") + .set_coordinator(coordinator) .with_termination_condition(lambda conv: len([m for m in conv if m.role == Role.USER]) >= 1) .with_request_info() .build() @@ -775,3 +780,893 @@ async def test_return_to_previous_state_serialization(): # Verify current_agent_id was restored assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage] + + +# region Participant Factory Tests + + +def test_handoff_builder_rejects_empty_participant_factories(): + """Test that HandoffBuilder rejects empty participant_factories dictionary.""" + # Empty factories are rejected immediately when calling participant_factories() + with pytest.raises(ValueError, match=r"participant_factories cannot be empty"): + HandoffBuilder().participant_factories({}) + + with pytest.raises(ValueError, match=r"No participants or participant_factories have been configured"): + HandoffBuilder(participant_factories={}).build() + + +def test_handoff_builder_rejects_mixing_participants_and_factories(): + """Test that mixing participants and participant_factories in __init__ raises an error.""" + triage = _RecordingAgent(name="triage") + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder(participants=[triage], participant_factories={"triage": lambda: triage}) + + +def test_handoff_builder_rejects_mixing_participants_and_participant_factories_methods(): + """Test that mixing .participants() and .participant_factories() raises an error.""" + triage = _RecordingAgent(name="triage") + + # Case 1: participants first, then participant_factories + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder(participants=[triage]).participant_factories({ + "specialist": lambda: _RecordingAgent(name="specialist") + }) + + # Case 2: participant_factories first, then participants + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder(participant_factories={"triage": lambda: triage}).participants([ + _RecordingAgent(name="specialist") + ]) + + # Case 3: participants(), then participant_factories() + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder().participants([triage]).participant_factories({ + "specialist": lambda: _RecordingAgent(name="specialist") + }) + + # Case 4: participant_factories(), then participants() + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder().participant_factories({"triage": lambda: triage}).participants([ + _RecordingAgent(name="specialist") + ]) + + # Case 5: mix during initialization + with pytest.raises(ValueError, match="Cannot mix .participants"): + HandoffBuilder( + participants=[triage], participant_factories={"specialist": lambda: _RecordingAgent(name="specialist")} + ) + + +def test_handoff_builder_rejects_multiple_calls_to_participant_factories(): + """Test that multiple calls to .participant_factories() raises an error.""" + with pytest.raises(ValueError, match=r"participant_factories\(\) has already been called"): + ( + HandoffBuilder() + .participant_factories({"agent1": lambda: _RecordingAgent(name="agent1")}) + .participant_factories({"agent2": lambda: _RecordingAgent(name="agent2")}) + ) + + +def test_handoff_builder_rejects_multiple_calls_to_participants(): + """Test that multiple calls to .participants() raises an error.""" + with pytest.raises(ValueError, match="participants have already been assigned"): + (HandoffBuilder().participants([_RecordingAgent(name="agent1")]).participants([_RecordingAgent(name="agent2")])) + + +def test_handoff_builder_rejects_duplicate_factories(): + """Test that multiple calls to participant_factories are rejected.""" + factories = { + "triage": lambda: _RecordingAgent(name="triage"), + "specialist": lambda: _RecordingAgent(name="specialist"), + } + + # Multiple calls to participant_factories should fail + builder = HandoffBuilder(participant_factories=factories) + with pytest.raises(ValueError, match=r"participant_factories\(\) has already been called"): + builder.participant_factories({"triage": lambda: _RecordingAgent(name="triage2")}) + + +def test_handoff_builder_rejects_instance_coordinator_with_factories(): + """Test that using an agent instance for set_coordinator when using factories raises an error.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + # Create an agent instance + coordinator_instance = _RecordingAgent(name="coordinator") + + with pytest.raises(ValueError, match=r"Call participants\(\.\.\.\) before coordinator\(\.\.\.\)"): + ( + HandoffBuilder( + participant_factories={"triage": create_triage, "specialist": create_specialist} + ).set_coordinator(coordinator_instance) # Instance, not factory name + ) + + +def test_handoff_builder_rejects_factory_name_coordinator_with_instances(): + """Test that using a factory name for set_coordinator when using instances raises an error.""" + triage = _RecordingAgent(name="triage") + specialist = _RecordingAgent(name="specialist") + + with pytest.raises( + ValueError, match="coordinator factory name 'triage' is not part of the participant_factories list" + ): + ( + HandoffBuilder(participants=[triage, specialist]).set_coordinator( + "triage" + ) # String factory name, not instance + ) + + +def test_handoff_builder_rejects_mixed_types_in_add_handoff_source(): + """Test that add_handoff rejects factory name source with instance-based participants.""" + triage = _RecordingAgent(name="triage") + specialist = _RecordingAgent(name="specialist") + + with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and AgentProtocol/Executor instances"): + ( + HandoffBuilder(participants=[triage, specialist]) + .set_coordinator(triage) + .add_handoff("triage", specialist) # String source with instance participants + ) + + +def test_handoff_builder_accepts_all_factory_names_in_add_handoff(): + """Test that add_handoff accepts all factory names when using participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist_a() -> _RecordingAgent: + return _RecordingAgent(name="specialist_a") + + def create_specialist_b() -> _RecordingAgent: + return _RecordingAgent(name="specialist_b") + + # This should work - all strings with participant_factories + builder = ( + HandoffBuilder( + participant_factories={ + "triage": create_triage, + "specialist_a": create_specialist_a, + "specialist_b": create_specialist_b, + } + ) + .set_coordinator("triage") + .add_handoff("triage", ["specialist_a", "specialist_b"]) + ) + + workflow = builder.build() + assert "triage" in workflow.executors + assert "specialist_a" in workflow.executors + assert "specialist_b" in workflow.executors + + +def test_handoff_builder_accepts_all_instances_in_add_handoff(): + """Test that add_handoff accepts all instances when using participants.""" + triage = _RecordingAgent(name="triage", handoff_to="specialist_a") + specialist_a = _RecordingAgent(name="specialist_a") + specialist_b = _RecordingAgent(name="specialist_b") + + # This should work - all instances with participants + builder = ( + HandoffBuilder(participants=[triage, specialist_a, specialist_b]) + .set_coordinator(triage) + .add_handoff(triage, [specialist_a, specialist_b]) + ) + + workflow = builder.build() + assert "triage" in workflow.executors + assert "specialist_a" in workflow.executors + assert "specialist_b" in workflow.executors + + +async def test_handoff_with_participant_factories(): + """Test workflow creation using participant_factories.""" + call_count = 0 + + def create_triage() -> _RecordingAgent: + nonlocal call_count + call_count += 1 + return _RecordingAgent(name="triage", handoff_to="specialist") + + def create_specialist() -> _RecordingAgent: + nonlocal call_count + call_count += 1 + return _RecordingAgent(name="specialist") + + workflow = ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2) + .build() + ) + + # Factories should be called during build + assert call_count == 2 + + events = await _drain(workflow.run_stream("Need help")) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Follow-up message + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "More details"})) + outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] + assert outputs + + +async def test_handoff_participant_factories_reusable_builder(): + """Test that the builder can be reused to build multiple workflows with factories.""" + call_count = 0 + + def create_triage() -> _RecordingAgent: + nonlocal call_count + call_count += 1 + return _RecordingAgent(name="triage", handoff_to="specialist") + + def create_specialist() -> _RecordingAgent: + nonlocal call_count + call_count += 1 + return _RecordingAgent(name="specialist") + + builder = HandoffBuilder( + participant_factories={"triage": create_triage, "specialist": create_specialist} + ).set_coordinator("triage") + + # Build first workflow + wf1 = builder.build() + assert call_count == 2 + + # Build second workflow + wf2 = builder.build() + assert call_count == 4 + + # Verify that the two workflows have different agent instances + assert wf1.executors["triage"] is not wf2.executors["triage"] + assert wf1.executors["specialist"] is not wf2.executors["specialist"] + + +async def test_handoff_with_participant_factories_and_add_handoff(): + """Test that .add_handoff() works correctly with participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist_a") + + def create_specialist_a() -> _RecordingAgent: + return _RecordingAgent(name="specialist_a", handoff_to="specialist_b") + + def create_specialist_b() -> _RecordingAgent: + return _RecordingAgent(name="specialist_b") + + workflow = ( + HandoffBuilder( + participant_factories={ + "triage": create_triage, + "specialist_a": create_specialist_a, + "specialist_b": create_specialist_b, + } + ) + .set_coordinator("triage") + .add_handoff("triage", ["specialist_a", "specialist_b"]) + .add_handoff("specialist_a", "specialist_b") + .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3) + .build() + ) + + # Start conversation - triage hands off to specialist_a + events = await _drain(workflow.run_stream("Initial request")) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Verify specialist_a executor exists and was called + assert "specialist_a" in workflow.executors + + # Second user message - specialist_a hands off to specialist_b + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need escalation"})) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Verify specialist_b executor exists + assert "specialist_b" in workflow.executors + + +async def test_handoff_participant_factories_with_checkpointing(): + """Test checkpointing with participant_factories.""" + from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage + + storage = InMemoryCheckpointStorage() + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + workflow = ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .with_checkpointing(storage) + .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2) + .build() + ) + + # Run workflow and capture output + events = await _drain(workflow.run_stream("checkpoint test")) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "follow up"})) + outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] + assert outputs, "Should have workflow output after termination condition is met" + + # List checkpoints - just verify they were created + checkpoints = await storage.list_checkpoints() + assert checkpoints, "Checkpoints should be created during workflow execution" + + +def test_handoff_set_coordinator_with_factory_name(): + """Test that set_coordinator accepts factory name as string.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + builder = HandoffBuilder( + participant_factories={"triage": create_triage, "specialist": create_specialist} + ).set_coordinator("triage") + + workflow = builder.build() + assert "triage" in workflow.executors + + +def test_handoff_add_handoff_with_factory_names(): + """Test that add_handoff accepts factory names as strings.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist_a") + + def create_specialist_a() -> _RecordingAgent: + return _RecordingAgent(name="specialist_a") + + def create_specialist_b() -> _RecordingAgent: + return _RecordingAgent(name="specialist_b") + + builder = ( + HandoffBuilder( + participant_factories={ + "triage": create_triage, + "specialist_a": create_specialist_a, + "specialist_b": create_specialist_b, + } + ) + .set_coordinator("triage") + .add_handoff("triage", ["specialist_a", "specialist_b"]) + ) + + workflow = builder.build() + assert "triage" in workflow.executors + assert "specialist_a" in workflow.executors + assert "specialist_b" in workflow.executors + + +async def test_handoff_participant_factories_autonomous_mode(): + """Test autonomous mode with participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + workflow = ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .with_interaction_mode("autonomous", autonomous_turn_limit=2) + .build() + ) + + events = await _drain(workflow.run_stream("Issue")) + outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] + assert outputs, "Autonomous mode should yield output" + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert not requests, "Autonomous mode should not request user input" + + +async def test_handoff_participant_factories_with_request_info(): + """Test that .with_request_info() works with participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + builder = ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .with_request_info(agents=["triage"]) + ) + + workflow = builder.build() + assert "triage" in workflow.executors + + +def test_handoff_participant_factories_invalid_coordinator_name(): + """Test that set_coordinator raises error for non-existent factory name.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + with pytest.raises( + ValueError, match="coordinator factory name 'nonexistent' is not part of the participant_factories list" + ): + (HandoffBuilder(participant_factories={"triage": create_triage}).set_coordinator("nonexistent").build()) + + +def test_handoff_participant_factories_invalid_handoff_target(): + """Test that add_handoff raises error for non-existent target factory name.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage") + + def create_specialist() -> _RecordingAgent: + return _RecordingAgent(name="specialist") + + with pytest.raises(ValueError, match="Target factory name 'nonexistent' is not in the participant_factories list"): + ( + HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist}) + .set_coordinator("triage") + .add_handoff("triage", "nonexistent") + .build() + ) + + +async def test_handoff_participant_factories_enable_return_to_previous(): + """Test return_to_previous works with participant_factories.""" + + def create_triage() -> _RecordingAgent: + return _RecordingAgent(name="triage", handoff_to="specialist_a") + + def create_specialist_a() -> _RecordingAgent: + return _RecordingAgent(name="specialist_a", handoff_to="specialist_b") + + def create_specialist_b() -> _RecordingAgent: + return _RecordingAgent(name="specialist_b") + + workflow = ( + HandoffBuilder( + participant_factories={ + "triage": create_triage, + "specialist_a": create_specialist_a, + "specialist_b": create_specialist_b, + } + ) + .set_coordinator("triage") + .add_handoff("triage", ["specialist_a", "specialist_b"]) + .add_handoff("specialist_a", "specialist_b") + .enable_return_to_previous(True) + .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3) + .build() + ) + + # Start conversation - triage hands off to specialist_a + events = await _drain(workflow.run_stream("Initial request")) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Second user message - specialist_a hands off to specialist_b + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need escalation"})) + requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)] + assert requests + + # Third user message - should route back to specialist_b (return to previous) + events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up"})) + outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)] + assert outputs or [ev for ev in events if isinstance(ev, RequestInfoEvent)] + + +# endregion Participant Factory Tests + + +async def test_handoff_user_input_request_checkpoint_excludes_conversation(): + """Test that HandoffUserInputRequest serialization excludes conversation to prevent duplication. + + Issue #2667: When checkpointing a workflow with a pending HandoffUserInputRequest, + the conversation field gets serialized twice: once in the RequestInfoEvent's data + and once in the coordinator's conversation state. On restore, this causes duplicate + messages. + + The fix is to exclude the conversation field during checkpoint serialization since + the conversation is already preserved in the coordinator's state. + """ + # Create a conversation history + conversation = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi there!"), + ChatMessage(role=Role.USER, text="Help me"), + ] + + # Create a HandoffUserInputRequest with the conversation + request = HandoffUserInputRequest( + conversation=conversation, + awaiting_agent_id="specialist_agent", + prompt="Please provide your input", + source_executor_id="gateway", + ) + + # Encode the request (simulating checkpoint save) + encoded = encode_checkpoint_value(request) + + # Verify conversation is NOT in the encoded output + # The fix should exclude conversation from serialization + assert isinstance(encoded, dict) + + # If using MODEL_MARKER strategy (to_dict/from_dict) + if "__af_model__" in encoded or "__af_dataclass__" in encoded: + value = encoded.get("value", {}) + assert "conversation" not in value, "conversation should be excluded from checkpoint serialization" + + # Decode the request (simulating checkpoint restore) + decoded = decode_checkpoint_value(encoded) + + # Verify the decoded request is a HandoffUserInputRequest + assert isinstance(decoded, HandoffUserInputRequest) + + # Verify other fields are preserved + assert decoded.awaiting_agent_id == "specialist_agent" + assert decoded.prompt == "Please provide your input" + assert decoded.source_executor_id == "gateway" + + # Conversation should be an empty list after deserialization + # (will be reconstructed from coordinator state on restore) + assert decoded.conversation == [] + + +async def test_handoff_user_input_request_roundtrip_preserves_metadata(): + """Test that non-conversation fields survive checkpoint roundtrip.""" + request = HandoffUserInputRequest( + conversation=[ChatMessage(role=Role.USER, text="test")], + awaiting_agent_id="test_agent", + prompt="Enter your response", + source_executor_id="test_gateway", + ) + + # Roundtrip through checkpoint encoding + encoded = encode_checkpoint_value(request) + decoded = decode_checkpoint_value(encoded) + + assert isinstance(decoded, HandoffUserInputRequest) + assert decoded.awaiting_agent_id == request.awaiting_agent_id + assert decoded.prompt == request.prompt + assert decoded.source_executor_id == request.source_executor_id + + +async def test_request_info_event_with_handoff_user_input_request(): + """Test RequestInfoEvent serialization with HandoffUserInputRequest data.""" + conversation = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="How can I help?"), + ] + + request = HandoffUserInputRequest( + conversation=conversation, + awaiting_agent_id="specialist", + prompt="Provide input", + source_executor_id="gateway", + ) + + # Create a RequestInfoEvent wrapping the request + event = RequestInfoEvent( + request_id="test-request-123", + source_executor_id="gateway", + request_data=request, + response_type=object, + ) + + # Serialize the event + event_dict = event.to_dict() + + # Verify the data field doesn't contain conversation + data_encoded = event_dict["data"] + if isinstance(data_encoded, dict) and ("__af_model__" in data_encoded or "__af_dataclass__" in data_encoded): + value = data_encoded.get("value", {}) + assert "conversation" not in value + + # Deserialize and verify + restored_event = RequestInfoEvent.from_dict(event_dict) + assert isinstance(restored_event.data, HandoffUserInputRequest) + assert restored_event.data.awaiting_agent_id == "specialist" + assert restored_event.data.conversation == [] + + +async def test_handoff_user_input_request_to_dict_excludes_conversation(): + """Test that to_dict() method excludes conversation field.""" + conversation = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi!"), + ] + + request = HandoffUserInputRequest( + conversation=conversation, + awaiting_agent_id="agent1", + prompt="Enter input", + source_executor_id="gateway", + ) + + # Call to_dict directly + data = request.to_dict() + + # Verify conversation is excluded + assert "conversation" not in data + assert data["awaiting_agent_id"] == "agent1" + assert data["prompt"] == "Enter input" + assert data["source_executor_id"] == "gateway" + + +async def test_handoff_user_input_request_from_dict_creates_empty_conversation(): + """Test that from_dict() creates an instance with empty conversation.""" + data = { + "awaiting_agent_id": "agent1", + "prompt": "Enter input", + "source_executor_id": "gateway", + } + + request = HandoffUserInputRequest.from_dict(data) + + assert request.conversation == [] + assert request.awaiting_agent_id == "agent1" + assert request.prompt == "Enter input" + assert request.source_executor_id == "gateway" + + +async def test_user_input_gateway_resume_handles_empty_conversation(): + """Test that _UserInputGateway.resume_from_user handles post-restore scenario. + + After checkpoint restore, the HandoffUserInputRequest will have an empty + conversation. The gateway should handle this by sending only the new user + messages to the coordinator. + """ + from unittest.mock import AsyncMock + + # Create a gateway + gateway = _UserInputGateway( + starting_agent_id="coordinator", + prompt="Enter input", + id="test-gateway", + ) + + # Simulate post-restore: request with empty conversation + restored_request = HandoffUserInputRequest( + conversation=[], # Empty after restore + awaiting_agent_id="specialist", + prompt="Enter input", + source_executor_id="test-gateway", + ) + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Call resume_from_user with a user response + await gateway.resume_from_user(restored_request, "New user message", mock_ctx) + + # Verify send_message was called + mock_ctx.send_message.assert_called_once() + + # Get the message that was sent + call_args = mock_ctx.send_message.call_args + sent_message = call_args[0][0] + + # Verify it's a _ConversationWithUserInput + assert isinstance(sent_message, _ConversationWithUserInput) + + # Verify it contains only the new user message (not any history) + assert len(sent_message.full_conversation) == 1 + assert sent_message.full_conversation[0].role == Role.USER + assert sent_message.full_conversation[0].text == "New user message" + + +async def test_user_input_gateway_resume_with_full_conversation(): + """Test that _UserInputGateway.resume_from_user handles normal flow correctly. + + In normal flow (no checkpoint restore), the HandoffUserInputRequest has + the full conversation. The gateway should send the full conversation + plus the new user messages. + """ + from unittest.mock import AsyncMock + + # Create a gateway + gateway = _UserInputGateway( + starting_agent_id="coordinator", + prompt="Enter input", + id="test-gateway", + ) + + # Normal flow: request with full conversation + normal_request = HandoffUserInputRequest( + conversation=[ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi!"), + ], + awaiting_agent_id="specialist", + prompt="Enter input", + source_executor_id="test-gateway", + ) + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Call resume_from_user with a user response + await gateway.resume_from_user(normal_request, "Follow up message", mock_ctx) + + # Verify send_message was called + mock_ctx.send_message.assert_called_once() + + # Get the message that was sent + call_args = mock_ctx.send_message.call_args + sent_message = call_args[0][0] + + # Verify it's a _ConversationWithUserInput + assert isinstance(sent_message, _ConversationWithUserInput) + + # Verify it contains the full conversation plus new user message + assert len(sent_message.full_conversation) == 3 + assert sent_message.full_conversation[0].text == "Hello" + assert sent_message.full_conversation[1].text == "Hi!" + assert sent_message.full_conversation[2].text == "Follow up message" + + +async def test_coordinator_handle_user_input_post_restore(): + """Test that _HandoffCoordinator.handle_user_input handles post-restore correctly. + + After checkpoint restore, the coordinator has its conversation restored, + and the gateway sends only the new user messages. The coordinator should + append these to its existing conversation rather than replacing. + """ + from unittest.mock import AsyncMock + + from agent_framework._workflows._handoff import _HandoffCoordinator + + # Create a coordinator with pre-existing conversation (simulating restored state) + coordinator = _HandoffCoordinator( + starting_agent_id="triage", + specialist_ids={"specialist_a": "specialist_a"}, + input_gateway_id="gateway", + termination_condition=lambda conv: False, + id="test-coordinator", + ) + + # Simulate restored conversation + coordinator._conversation = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi there!"), + ChatMessage(role=Role.USER, text="Help me"), + ChatMessage(role=Role.ASSISTANT, text="Sure, what do you need?"), + ] + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Simulate post-restore: only new user message with explicit flag + incoming = _ConversationWithUserInput( + full_conversation=[ChatMessage(role=Role.USER, text="I need shipping help")], + is_post_restore=True, + ) + + # Handle the user input + await coordinator.handle_user_input(incoming, mock_ctx) + + # Verify conversation was appended, not replaced + assert len(coordinator._conversation) == 5 + assert coordinator._conversation[0].text == "Hello" + assert coordinator._conversation[1].text == "Hi there!" + assert coordinator._conversation[2].text == "Help me" + assert coordinator._conversation[3].text == "Sure, what do you need?" + assert coordinator._conversation[4].text == "I need shipping help" + + +async def test_coordinator_handle_user_input_normal_flow(): + """Test that _HandoffCoordinator.handle_user_input handles normal flow correctly. + + In normal flow (no restore), the gateway sends the full conversation. + The coordinator should replace its conversation with the incoming one. + """ + from unittest.mock import AsyncMock + + from agent_framework._workflows._handoff import _HandoffCoordinator + + # Create a coordinator + coordinator = _HandoffCoordinator( + starting_agent_id="triage", + specialist_ids={"specialist_a": "specialist_a"}, + input_gateway_id="gateway", + termination_condition=lambda conv: False, + id="test-coordinator", + ) + + # Set some initial conversation + coordinator._conversation = [ + ChatMessage(role=Role.USER, text="Old message"), + ] + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Normal flow: full conversation including new user message (is_post_restore=False by default) + incoming = _ConversationWithUserInput( + full_conversation=[ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi!"), + ChatMessage(role=Role.USER, text="New message"), + ], + is_post_restore=False, + ) + + # Handle the user input + await coordinator.handle_user_input(incoming, mock_ctx) + + # Verify conversation was replaced (normal flow with full history) + assert len(coordinator._conversation) == 3 + assert coordinator._conversation[0].text == "Hello" + assert coordinator._conversation[1].text == "Hi!" + assert coordinator._conversation[2].text == "New message" + + +async def test_coordinator_handle_user_input_multiple_consecutive_user_messages(): + """Test that multiple consecutive USER messages in normal flow are handled correctly. + + This is a regression test for the edge case where a user submits multiple consecutive + USER messages. The explicit is_post_restore flag ensures this doesn't get incorrectly + detected as a post-restore scenario. + """ + from unittest.mock import AsyncMock + + from agent_framework._workflows._handoff import _HandoffCoordinator + + # Create a coordinator with existing conversation + coordinator = _HandoffCoordinator( + starting_agent_id="triage", + specialist_ids={"specialist_a": "specialist_a"}, + input_gateway_id="gateway", + termination_condition=lambda conv: False, + id="test-coordinator", + ) + + # Set existing conversation with 4 messages + coordinator._conversation = [ + ChatMessage(role=Role.USER, text="Original message 1"), + ChatMessage(role=Role.ASSISTANT, text="Response 1"), + ChatMessage(role=Role.USER, text="Original message 2"), + ChatMessage(role=Role.ASSISTANT, text="Response 2"), + ] + + # Create mock context + mock_ctx = MagicMock() + mock_ctx.send_message = AsyncMock() + + # Normal flow: User sends multiple consecutive USER messages + # This should REPLACE the conversation, not append to it + incoming = _ConversationWithUserInput( + full_conversation=[ + ChatMessage(role=Role.USER, text="New user message 1"), + ChatMessage(role=Role.USER, text="New user message 2"), + ], + is_post_restore=False, # Explicit flag - this is normal flow + ) + + # Handle the user input + await coordinator.handle_user_input(incoming, mock_ctx) + + # Verify conversation was REPLACED (not appended) + # Without the explicit flag, the old heuristic might incorrectly append + assert len(coordinator._conversation) == 2 + assert coordinator._conversation[0].text == "New user message 1" + assert coordinator._conversation[1].text == "New user message 2" diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 71cfc6752a..4ee16ddb5f 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -876,3 +876,204 @@ def test_magentic_builder_does_not_have_human_input_hook(): "MagenticBuilder should not have with_human_input_hook - " "use with_plan_review() or with_human_input_on_stall() instead" ) + + +# region Message Deduplication Tests + + +async def test_magentic_no_duplicate_messages_with_conversation_history(): + """Test that passing list[ChatMessage] does not create duplicate messages in chat_history. + + When a frontend passes conversation history as list[ChatMessage], the last message + (task) should not be duplicated in the orchestrator's chat_history. + """ + manager = FakeManager(max_round_count=10) + manager.satisfied_after_signoff = True # Complete immediately after first agent response + + wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build() + + # Simulate frontend passing conversation history + conversation: list[ChatMessage] = [ + ChatMessage(role=Role.USER, text="previous question"), + ChatMessage(role=Role.ASSISTANT, text="previous answer"), + ChatMessage(role=Role.USER, text="current task"), + ] + + # Get orchestrator to inspect chat_history after run + orchestrator = None + for executor in wf.executors.values(): + if isinstance(executor, MagenticOrchestratorExecutor): + orchestrator = executor + break + + events: list[WorkflowEvent] = [] + async for event in wf.run_stream(conversation): + events.append(event) + if isinstance(event, WorkflowStatusEvent) and event.state in ( + WorkflowRunState.IDLE, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + ): + break + + assert orchestrator is not None + assert orchestrator._context is not None # type: ignore[reportPrivateUsage] + + # Count occurrences of each message text in chat_history + history = orchestrator._context.chat_history # type: ignore[reportPrivateUsage] + user_task_count = sum(1 for msg in history if msg.text == "current task") + prev_question_count = sum(1 for msg in history if msg.text == "previous question") + prev_answer_count = sum(1 for msg in history if msg.text == "previous answer") + + # Each input message should appear exactly once (no duplicates) + assert prev_question_count == 1, f"Expected 1 'previous question', got {prev_question_count}" + assert prev_answer_count == 1, f"Expected 1 'previous answer', got {prev_answer_count}" + assert user_task_count == 1, f"Expected 1 'current task', got {user_task_count}" + + +async def test_magentic_agent_executor_no_duplicate_messages_on_broadcast(): + """Test that MagenticAgentExecutor does not duplicate messages from broadcasts. + + When the orchestrator broadcasts the task ledger to all agents, each agent + should receive it exactly once, not multiple times. + """ + backing_executor = _DummyExec("backing") + agent_exec = MagenticAgentExecutor(backing_executor, "agentA") + + # Simulate orchestrator sending a broadcast message + broadcast_msg = ChatMessage( + role=Role.ASSISTANT, + text="Task ledger content", + author_name="magentic_manager", + ) + + # Simulate the same message being received multiple times (e.g., from checkpoint restore + live) + from agent_framework._workflows._magentic import _MagenticResponseMessage + + response1 = _MagenticResponseMessage(body=broadcast_msg, broadcast=True) + response2 = _MagenticResponseMessage(body=broadcast_msg, broadcast=True) + + # Create a mock context + from unittest.mock import AsyncMock, MagicMock + + mock_context = MagicMock() + mock_context.send_message = AsyncMock() + + # Call the handler twice with the same message + await agent_exec.handle_response_message(response1, mock_context) # type: ignore[arg-type] + await agent_exec.handle_response_message(response2, mock_context) # type: ignore[arg-type] + + # Count how many times the broadcast message appears + history = agent_exec._chat_history # type: ignore[reportPrivateUsage] + broadcast_count = sum(1 for msg in history if msg.text == "Task ledger content") + + # Each broadcast should be recorded (this is expected behavior - broadcasts are additive) + # The test documents current behavior. If dedup is needed, this assertion would change. + assert broadcast_count == 2, ( + f"Expected 2 broadcasts (current behavior is additive), got {broadcast_count}. " + "If deduplication is required, update the handler logic." + ) + + +async def test_magentic_context_no_duplicate_on_reset(): + """Test that MagenticContext.reset() clears chat_history without leaving duplicates.""" + ctx = MagenticContext( + task=ChatMessage(role=Role.USER, text="task"), + participant_descriptions={"Alice": "Researcher"}, + ) + + # Add some history + ctx.chat_history.append(ChatMessage(role=Role.ASSISTANT, text="response1")) + ctx.chat_history.append(ChatMessage(role=Role.ASSISTANT, text="response2")) + assert len(ctx.chat_history) == 2 + + # Reset + ctx.reset() + + # Verify clean slate + assert len(ctx.chat_history) == 0, "chat_history should be empty after reset" + + # Add new history + ctx.chat_history.append(ChatMessage(role=Role.ASSISTANT, text="new_response")) + assert len(ctx.chat_history) == 1, "Should have exactly 1 message after adding to reset context" + + +async def test_magentic_start_message_messages_list_integrity(): + """Test that _MagenticStartMessage preserves message list without internal duplication.""" + conversation: list[ChatMessage] = [ + ChatMessage(role=Role.USER, text="msg1"), + ChatMessage(role=Role.ASSISTANT, text="msg2"), + ChatMessage(role=Role.USER, text="msg3"), + ] + + start_msg = _MagenticStartMessage(conversation) + + # Verify messages list is preserved + assert len(start_msg.messages) == 3, f"Expected 3 messages, got {len(start_msg.messages)}" + + # Verify task is the last message (not a copy) + assert start_msg.task is start_msg.messages[-1], "task should be the same object as messages[-1]" + assert start_msg.task.text == "msg3" + + +async def test_magentic_checkpoint_restore_no_duplicate_history(): + """Test that checkpoint restore does not create duplicate messages in chat_history.""" + manager = FakeManager(max_round_count=10) + storage = InMemoryCheckpointStorage() + + wf = ( + MagenticBuilder() + .participants(agentA=_DummyExec("agentA")) + .with_standard_manager(manager) + .with_checkpointing(storage) + .build() + ) + + # Run with conversation history to create initial checkpoint + conversation: list[ChatMessage] = [ + ChatMessage(role=Role.USER, text="history_msg"), + ChatMessage(role=Role.USER, text="task_msg"), + ] + + async for event in wf.run_stream(conversation): + if isinstance(event, WorkflowStatusEvent) and event.state in ( + WorkflowRunState.IDLE, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + ): + break + + # Get checkpoint + checkpoints = await storage.list_checkpoints() + assert len(checkpoints) > 0, "Should have created checkpoints" + + latest_checkpoint = checkpoints[-1] + + # Load checkpoint and verify no duplicates in shared state + checkpoint_data = await storage.load_checkpoint(latest_checkpoint.checkpoint_id) + assert checkpoint_data is not None + + # Check the magentic_context in the checkpoint + for _, executor_state in checkpoint_data.metadata.items(): + if isinstance(executor_state, dict) and "magentic_context" in executor_state: + ctx_data = executor_state["magentic_context"] + chat_history = ctx_data.get("chat_history", []) + + # Count unique messages by text + texts = [ + msg.get("text") or (msg.get("contents", [{}])[0].get("text") if msg.get("contents") else None) + for msg in chat_history + ] + text_counts: dict[str, int] = {} + for text in texts: + if text: + text_counts[text] = text_counts.get(text, 0) + 1 + + # Input messages should not be duplicated + assert text_counts.get("history_msg", 0) <= 1, ( + f"'history_msg' appears {text_counts.get('history_msg', 0)} times in checkpoint - expected <= 1" + ) + assert text_counts.get("task_msg", 0) <= 1, ( + f"'task_msg' appears {text_counts.get('task_msg', 0)} times in checkpoint - expected <= 1" + ) + + +# endregion diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 009ead6edd..51b3544b22 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -9,18 +9,23 @@ AgentRunResponse, AgentRunResponseUpdate, AgentRunUpdateEvent, + AgentThread, ChatMessage, + ChatMessageStore, + DataContent, Executor, FunctionApprovalRequestContent, FunctionApprovalResponseContent, FunctionCallContent, Role, TextContent, + UriContent, UsageContent, UsageDetails, WorkflowAgent, WorkflowBuilder, WorkflowContext, + executor, handler, response_handler, ) @@ -75,6 +80,31 @@ async def handle_request_response( await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=update)) +class ConversationHistoryCapturingExecutor(Executor): + """Executor that captures the received conversation history for verification.""" + + def __init__(self, id: str): + super().__init__(id=id) + self.received_messages: list[ChatMessage] = [] + + @handler + async def handle_message(self, messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: + # Capture all received messages + self.received_messages = list(messages) + + # Count messages by role for the response + message_count = len(messages) + response_text = f"Received {message_count} messages" + + response_message = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)]) + + streaming_update = AgentRunResponseUpdate( + contents=[TextContent(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4()) + ) + await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=streaming_update)) + await ctx.send_message([response_message]) + + class TestWorkflowAgent: """Test cases for WorkflowAgent end-to-end functionality.""" @@ -257,6 +287,240 @@ async def handle_bool(self, message: bool, context: WorkflowContext[Any]) -> Non with pytest.raises(ValueError, match="Workflow's start executor cannot handle list\\[ChatMessage\\]"): workflow.as_agent() + async def test_workflow_as_agent_yield_output_surfaces_as_agent_response(self) -> None: + """Test that ctx.yield_output() in a workflow executor surfaces as agent output when using .as_agent(). + + This validates the fix for issue #2813: WorkflowOutputEvent should be converted to + AgentRunResponseUpdate when the workflow is wrapped via .as_agent(). + """ + + @executor + async def yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + # Extract text from input for demonstration + input_text = messages[0].text if messages else "no input" + await ctx.yield_output(f"processed: {input_text}") + + workflow = WorkflowBuilder().set_start_executor(yielding_executor).build() + + # Run directly - should return WorkflowOutputEvent in result + direct_result = await workflow.run([ChatMessage(role=Role.USER, contents=[TextContent(text="hello")])]) + direct_outputs = direct_result.get_outputs() + assert len(direct_outputs) == 1 + assert direct_outputs[0] == "processed: hello" + + # Run as agent - yield_output should surface as agent response message + agent = workflow.as_agent("test-agent") + agent_result = await agent.run("hello") + + assert isinstance(agent_result, AgentRunResponse) + assert len(agent_result.messages) == 1 + assert agent_result.messages[0].text == "processed: hello" + + async def test_workflow_as_agent_yield_output_surfaces_in_run_stream(self) -> None: + """Test that ctx.yield_output() surfaces as AgentRunResponseUpdate when streaming.""" + + @executor + async def yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + await ctx.yield_output("first output") + await ctx.yield_output("second output") + + workflow = WorkflowBuilder().set_start_executor(yielding_executor).build() + agent = workflow.as_agent("test-agent") + + updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream("hello"): + updates.append(update) + + # Should have received updates for both yield_output calls + texts = [u.text for u in updates if u.text] + assert "first output" in texts + assert "second output" in texts + + async def test_workflow_as_agent_yield_output_with_content_types(self) -> None: + """Test that yield_output preserves different content types (TextContent, DataContent, etc.).""" + + @executor + async def content_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + # Yield different content types + await ctx.yield_output(TextContent(text="text content")) + await ctx.yield_output(DataContent(data=b"binary data", media_type="application/octet-stream")) + await ctx.yield_output(UriContent(uri="https://example.com/image.png", media_type="image/png")) + + workflow = WorkflowBuilder().set_start_executor(content_yielding_executor).build() + agent = workflow.as_agent("content-test-agent") + + result = await agent.run("test") + + assert isinstance(result, AgentRunResponse) + assert len(result.messages) == 3 + + # Verify each content type is preserved + assert isinstance(result.messages[0].contents[0], TextContent) + assert result.messages[0].contents[0].text == "text content" + + assert isinstance(result.messages[1].contents[0], DataContent) + assert result.messages[1].contents[0].media_type == "application/octet-stream" + + assert isinstance(result.messages[2].contents[0], UriContent) + assert result.messages[2].contents[0].uri == "https://example.com/image.png" + + async def test_workflow_as_agent_yield_output_with_chat_message(self) -> None: + """Test that yield_output with ChatMessage preserves the message structure.""" + + @executor + async def chat_message_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + msg = ChatMessage( + role=Role.ASSISTANT, + contents=[TextContent(text="response text")], + author_name="custom-author", + ) + await ctx.yield_output(msg) + + workflow = WorkflowBuilder().set_start_executor(chat_message_executor).build() + agent = workflow.as_agent("chat-msg-agent") + + result = await agent.run("test") + + assert len(result.messages) == 1 + assert result.messages[0].role == Role.ASSISTANT + assert result.messages[0].text == "response text" + assert result.messages[0].author_name == "custom-author" + + async def test_workflow_as_agent_yield_output_sets_raw_representation(self) -> None: + """Test that yield_output sets raw_representation with the original data.""" + + # A custom object to verify raw_representation preserves the original data + class CustomData: + def __init__(self, value: int): + self.value = value + + def __str__(self) -> str: + return f"CustomData({self.value})" + + @executor + async def raw_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: + # Yield different types of data + await ctx.yield_output("simple string") + await ctx.yield_output(TextContent(text="text content")) + custom = CustomData(42) + await ctx.yield_output(custom) + + workflow = WorkflowBuilder().set_start_executor(raw_yielding_executor).build() + agent = workflow.as_agent("raw-test-agent") + + updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream("test"): + updates.append(update) + + # Should have 3 updates + assert len(updates) == 3 + + # Verify raw_representation is set for each update + assert updates[0].raw_representation == "simple string" + assert isinstance(updates[1].raw_representation, TextContent) + assert updates[1].raw_representation.text == "text content" + assert isinstance(updates[2].raw_representation, CustomData) + assert updates[2].raw_representation.value == 42 + + async def test_thread_conversation_history_included_in_workflow_run(self) -> None: + """Test that conversation history from thread is included when running WorkflowAgent. + + This verifies that when a thread with existing messages is provided to agent.run(), + the workflow receives the complete conversation history (thread history + new messages). + """ + # Create an executor that captures all received messages + capturing_executor = ConversationHistoryCapturingExecutor(id="capturing") + workflow = WorkflowBuilder().set_start_executor(capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="Thread History Test Agent") + + # Create a thread with existing conversation history + history_messages = [ + ChatMessage(role=Role.USER, text="Previous user message"), + ChatMessage(role=Role.ASSISTANT, text="Previous assistant response"), + ] + message_store = ChatMessageStore(messages=history_messages) + thread = AgentThread(message_store=message_store) + + # Run the agent with the thread and a new message + new_message = "New user question" + await agent.run(new_message, thread=thread) + + # Verify the executor received both history AND new message + assert len(capturing_executor.received_messages) == 3 + + # Verify the order: history first, then new message + assert capturing_executor.received_messages[0].text == "Previous user message" + assert capturing_executor.received_messages[1].text == "Previous assistant response" + assert capturing_executor.received_messages[2].text == "New user question" + + async def test_thread_conversation_history_included_in_workflow_stream(self) -> None: + """Test that conversation history from thread is included when streaming WorkflowAgent. + + This verifies that run_stream also includes thread history. + """ + # Create an executor that captures all received messages + capturing_executor = ConversationHistoryCapturingExecutor(id="capturing_stream") + workflow = WorkflowBuilder().set_start_executor(capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="Thread Stream Test Agent") + + # Create a thread with existing conversation history + history_messages = [ + ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant"), + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi there!"), + ] + message_store = ChatMessageStore(messages=history_messages) + thread = AgentThread(message_store=message_store) + + # Stream from the agent with the thread and a new message + async for _ in agent.run_stream("How are you?", thread=thread): + pass + + # Verify the executor received all messages (3 from history + 1 new) + assert len(capturing_executor.received_messages) == 4 + + # Verify the order + assert capturing_executor.received_messages[0].text == "You are a helpful assistant" + assert capturing_executor.received_messages[1].text == "Hello" + assert capturing_executor.received_messages[2].text == "Hi there!" + assert capturing_executor.received_messages[3].text == "How are you?" + + async def test_empty_thread_works_correctly(self) -> None: + """Test that an empty thread (no message store) works correctly.""" + capturing_executor = ConversationHistoryCapturingExecutor(id="empty_thread_test") + workflow = WorkflowBuilder().set_start_executor(capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="Empty Thread Test Agent") + + # Create an empty thread + thread = AgentThread() + + # Run with the empty thread + await agent.run("Just a new message", thread=thread) + + # Should only receive the new message + assert len(capturing_executor.received_messages) == 1 + assert capturing_executor.received_messages[0].text == "Just a new message" + + async def test_checkpoint_storage_passed_to_workflow(self) -> None: + """Test that checkpoint_storage parameter is passed through to the workflow.""" + from agent_framework import InMemoryCheckpointStorage + + capturing_executor = ConversationHistoryCapturingExecutor(id="checkpoint_test") + workflow = WorkflowBuilder().set_start_executor(capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="Checkpoint Test Agent") + + # Create checkpoint storage + checkpoint_storage = InMemoryCheckpointStorage() + + # Run with checkpoint storage enabled + async for _ in agent.run_stream("Test message", checkpoint_storage=checkpoint_storage): + pass + + # Drain workflow events to get checkpoint + # The workflow should have created checkpoints + checkpoints = await checkpoint_storage.list_checkpoints(workflow.id) + assert len(checkpoints) > 0, "Checkpoints should have been created when checkpoint_storage is provided" + class TestWorkflowAgentMergeUpdates: """Test cases specifically for the WorkflowAgent.merge_updates static method.""" diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index a037bf51b6..83c9d41c22 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -293,6 +293,20 @@ def test_register_duplicate_name_raises_error(): builder.register_executor(lambda: MockExecutor(id="executor_2"), name="MyExecutor") +def test_register_duplicate_id_raises_error(): + """Test that registering duplicate id raises an error.""" + builder = WorkflowBuilder() + + # Register first executor + builder.register_executor(lambda: MockExecutor(id="executor"), name="MyExecutor1") + builder.register_executor(lambda: MockExecutor(id="executor"), name="MyExecutor2") + builder.set_start_executor("MyExecutor1") + + # Registering second executor with same ID should raise ValueError + with pytest.raises(ValueError, match="Executor with ID 'executor' has already been created."): + builder.build() + + def test_register_agent_basic(): """Test basic agent registration with lazy initialization.""" builder = WorkflowBuilder() @@ -483,7 +497,13 @@ def test_mixing_eager_and_lazy_initialization_error(): builder.register_executor(lambda: MockExecutor(id="Lazy"), name="Lazy") # Mixing eager and lazy should raise an error during add_edge - with pytest.raises(ValueError, match="Both source and target must be either names"): + with pytest.raises( + ValueError, + match=( + r"Both source and target must be either registered factory names \(str\) " + r"or Executor/AgentProtocol instances\." + ), + ): builder.add_edge(eager_executor, "Lazy") diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py new file mode 100644 index 0000000000..864258b76c --- /dev/null +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -0,0 +1,492 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable +from typing import Annotated, Any + +from agent_framework import ( + AgentRunResponse, + AgentRunResponseUpdate, + AgentThread, + BaseAgent, + ChatMessage, + ConcurrentBuilder, + GroupChatBuilder, + GroupChatStateSnapshot, + HandoffBuilder, + Role, + SequentialBuilder, + TextContent, + WorkflowRunState, + WorkflowStatusEvent, + ai_function, +) +from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY + +# Track kwargs received by tools during test execution +_received_kwargs: list[dict[str, Any]] = [] + + +def _reset_received_kwargs() -> None: + """Reset the kwargs tracker before each test.""" + _received_kwargs.clear() + + +@ai_function +def tool_with_kwargs( + action: Annotated[str, "The action to perform"], + **kwargs: Any, +) -> str: + """A test tool that captures kwargs for verification.""" + _received_kwargs.append(dict(kwargs)) + custom_data = kwargs.get("custom_data", {}) + user_token = kwargs.get("user_token", {}) + return f"Executed {action} with custom_data={custom_data}, user={user_token.get('user_name', 'unknown')}" + + +class _KwargsCapturingAgent(BaseAgent): + """Test agent that captures kwargs passed to run/run_stream.""" + + captured_kwargs: list[dict[str, Any]] + + def __init__(self, name: str = "test_agent") -> None: + super().__init__(name=name, description="Test agent for kwargs capture") + self.captured_kwargs = [] + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + self.captured_kwargs.append(dict(kwargs)) + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} response")]) + + async def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + self.captured_kwargs.append(dict(kwargs)) + yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} response")]) + + +class _EchoAgent(BaseAgent): + """Simple agent that echoes back for workflow completion.""" + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} reply")]) + + async def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} reply")]) + + +# region Sequential Builder Tests + + +async def test_sequential_kwargs_flow_to_agent() -> None: + """Test that kwargs passed to SequentialBuilder workflow flow through to agent.""" + agent = _KwargsCapturingAgent(name="seq_agent") + workflow = SequentialBuilder().participants([agent]).build() + + custom_data = {"endpoint": "https://api.example.com", "version": "v1"} + user_token = {"user_name": "alice", "access_level": "admin"} + + async for event in workflow.run_stream( + "test message", + custom_data=custom_data, + user_token=user_token, + ): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Verify agent received kwargs + assert len(agent.captured_kwargs) >= 1, "Agent should have been invoked at least once" + received = agent.captured_kwargs[0] + assert "custom_data" in received, "Agent should receive custom_data kwarg" + assert "user_token" in received, "Agent should receive user_token kwarg" + assert received["custom_data"] == custom_data + assert received["user_token"] == user_token + + +async def test_sequential_kwargs_flow_to_multiple_agents() -> None: + """Test that kwargs flow to all agents in a sequential workflow.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder().participants([agent1, agent2]).build() + + custom_data = {"key": "value"} + + async for event in workflow.run_stream("test", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Both agents should have received kwargs + assert len(agent1.captured_kwargs) >= 1, "First agent should be invoked" + assert len(agent2.captured_kwargs) >= 1, "Second agent should be invoked" + assert agent1.captured_kwargs[0].get("custom_data") == custom_data + assert agent2.captured_kwargs[0].get("custom_data") == custom_data + + +async def test_sequential_run_kwargs_flow() -> None: + """Test that kwargs flow through workflow.run() (non-streaming).""" + agent = _KwargsCapturingAgent(name="run_agent") + workflow = SequentialBuilder().participants([agent]).build() + + _ = await workflow.run("test message", custom_data={"test": True}) + + assert len(agent.captured_kwargs) >= 1 + assert agent.captured_kwargs[0].get("custom_data") == {"test": True} + + +# endregion + + +# region Concurrent Builder Tests + + +async def test_concurrent_kwargs_flow_to_agents() -> None: + """Test that kwargs flow to all agents in a concurrent workflow.""" + agent1 = _KwargsCapturingAgent(name="concurrent1") + agent2 = _KwargsCapturingAgent(name="concurrent2") + workflow = ConcurrentBuilder().participants([agent1, agent2]).build() + + custom_data = {"batch_id": "123"} + user_token = {"user_name": "bob"} + + async for event in workflow.run_stream( + "concurrent test", + custom_data=custom_data, + user_token=user_token, + ): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Both agents should have received kwargs + assert len(agent1.captured_kwargs) >= 1, "First concurrent agent should be invoked" + assert len(agent2.captured_kwargs) >= 1, "Second concurrent agent should be invoked" + + for agent in [agent1, agent2]: + received = agent.captured_kwargs[0] + assert received.get("custom_data") == custom_data + assert received.get("user_token") == user_token + + +# endregion + + +# region GroupChat Builder Tests + + +async def test_groupchat_kwargs_flow_to_agents() -> None: + """Test that kwargs flow to agents in a group chat workflow.""" + agent1 = _KwargsCapturingAgent(name="chat1") + agent2 = _KwargsCapturingAgent(name="chat2") + + # Simple selector that takes GroupChatStateSnapshot + turn_count = 0 + + def simple_selector(state: GroupChatStateSnapshot) -> str | None: + nonlocal turn_count + turn_count += 1 + if turn_count > 2: # Stop after 2 turns + return None + # state is a Mapping - access via dict syntax + names = list(state["participants"].keys()) + return names[(turn_count - 1) % len(names)] + + workflow = ( + GroupChatBuilder().participants(chat1=agent1, chat2=agent2).set_select_speakers_func(simple_selector).build() + ) + + custom_data = {"session_id": "group123"} + + async for event in workflow.run_stream("group chat test", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # At least one agent should have received kwargs + all_kwargs = agent1.captured_kwargs + agent2.captured_kwargs + assert len(all_kwargs) >= 1, "At least one agent should be invoked in group chat" + + for received in all_kwargs: + assert received.get("custom_data") == custom_data + + +# endregion + + +# region SharedState Verification Tests + + +async def test_kwargs_stored_in_shared_state() -> None: + """Test that kwargs are stored in SharedState with the correct key.""" + from agent_framework import Executor, WorkflowContext, handler + + stored_kwargs: dict[str, Any] | None = None + + class _SharedStateInspector(Executor): + @handler + async def inspect(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: + nonlocal stored_kwargs + stored_kwargs = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + await ctx.send_message(msgs) + + inspector = _SharedStateInspector(id="inspector") + workflow = SequentialBuilder().participants([inspector]).build() + + async for event in workflow.run_stream("test", my_kwarg="my_value", another=123): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + assert stored_kwargs is not None, "kwargs should be stored in SharedState" + assert stored_kwargs.get("my_kwarg") == "my_value" + assert stored_kwargs.get("another") == 123 + + +async def test_empty_kwargs_stored_as_empty_dict() -> None: + """Test that empty kwargs are stored as empty dict in SharedState.""" + from agent_framework import Executor, WorkflowContext, handler + + stored_kwargs: Any = "NOT_CHECKED" + + class _SharedStateChecker(Executor): + @handler + async def check(self, msgs: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: + nonlocal stored_kwargs + stored_kwargs = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) + await ctx.send_message(msgs) + + checker = _SharedStateChecker(id="checker") + workflow = SequentialBuilder().participants([checker]).build() + + # Run without any kwargs + async for event in workflow.run_stream("test"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # SharedState should have empty dict when no kwargs provided + assert stored_kwargs == {}, f"Expected empty dict, got: {stored_kwargs}" + + +# endregion + + +# region Edge Cases + + +async def test_kwargs_with_none_values() -> None: + """Test that kwargs with None values are passed through correctly.""" + agent = _KwargsCapturingAgent(name="none_test") + workflow = SequentialBuilder().participants([agent]).build() + + async for event in workflow.run_stream("test", optional_param=None, other_param="value"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + assert len(agent.captured_kwargs) >= 1 + received = agent.captured_kwargs[0] + assert "optional_param" in received + assert received["optional_param"] is None + assert received["other_param"] == "value" + + +async def test_kwargs_with_complex_nested_data() -> None: + """Test that complex nested data structures flow through correctly.""" + agent = _KwargsCapturingAgent(name="nested_test") + workflow = SequentialBuilder().participants([agent]).build() + + complex_data = { + "level1": { + "level2": { + "level3": ["a", "b", "c"], + "number": 42, + }, + "list": [1, 2, {"nested": True}], + }, + "tuple_like": [1, 2, 3], + } + + async for event in workflow.run_stream("test", complex_data=complex_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + assert len(agent.captured_kwargs) >= 1 + received = agent.captured_kwargs[0] + assert received.get("complex_data") == complex_data + + +async def test_kwargs_preserved_across_workflow_reruns() -> None: + """Test that kwargs are correctly isolated between workflow runs.""" + agent = _KwargsCapturingAgent(name="rerun_test") + + # Build separate workflows for each run to avoid "already running" error + workflow1 = SequentialBuilder().participants([agent]).build() + workflow2 = SequentialBuilder().participants([agent]).build() + + # First run + async for event in workflow1.run_stream("run1", run_id="first"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Second run with different kwargs (using fresh workflow) + async for event in workflow2.run_stream("run2", run_id="second"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + assert len(agent.captured_kwargs) >= 2 + assert agent.captured_kwargs[0].get("run_id") == "first" + assert agent.captured_kwargs[1].get("run_id") == "second" + + +# endregion + + +# region Handoff Builder Tests + + +async def test_handoff_kwargs_flow_to_agents() -> None: + """Test that kwargs flow to agents in a handoff workflow.""" + agent1 = _KwargsCapturingAgent(name="coordinator") + agent2 = _KwargsCapturingAgent(name="specialist") + + workflow = ( + HandoffBuilder() + .participants([agent1, agent2]) + .set_coordinator(agent1) + .with_interaction_mode("autonomous") + .build() + ) + + custom_data = {"session_id": "handoff123"} + + async for event in workflow.run_stream("handoff test", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Coordinator agent should have received kwargs + assert len(agent1.captured_kwargs) >= 1, "Coordinator should be invoked in handoff" + assert agent1.captured_kwargs[0].get("custom_data") == custom_data + + +# endregion + + +# region Magentic Builder Tests + + +async def test_magentic_kwargs_flow_to_agents() -> None: + """Test that kwargs flow to agents in a magentic workflow via MagenticAgentExecutor.""" + from agent_framework import MagenticBuilder + from agent_framework._workflows._magentic import ( + MagenticContext, + MagenticManagerBase, + _MagenticProgressLedger, + _MagenticProgressLedgerItem, + ) + + # Create a mock manager that completes after one round + class _MockManager(MagenticManagerBase): + def __init__(self) -> None: + super().__init__(max_stall_count=3, max_reset_count=None, max_round_count=2) + self.task_ledger = None + + async def plan(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Plan: Test task", author_name="manager") + + async def replan(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Replan: Test task", author_name="manager") + + async def create_progress_ledger(self, context: MagenticContext) -> _MagenticProgressLedger: + # Return completed on first call + return _MagenticProgressLedger( + is_request_satisfied=_MagenticProgressLedgerItem(answer=True, reason="Done"), + is_progress_being_made=_MagenticProgressLedgerItem(answer=True, reason="Progress"), + is_in_loop=_MagenticProgressLedgerItem(answer=False, reason="Not looping"), + instruction_or_question=_MagenticProgressLedgerItem(answer="Complete", reason="Done"), + next_speaker=_MagenticProgressLedgerItem(answer="agent1", reason="First"), + ) + + async def prepare_final_answer(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Final answer", author_name="manager") + + agent = _KwargsCapturingAgent(name="agent1") + manager = _MockManager() + + workflow = MagenticBuilder().participants(agent1=agent).with_standard_manager(manager=manager).build() + + custom_data = {"session_id": "magentic123"} + + async for event in workflow.run_stream("magentic test", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # The workflow completes immediately via prepare_final_answer without invoking agents + # because is_request_satisfied=True. This test verifies the kwargs storage path works. + # A more comprehensive integration test would require the manager to select an agent. + + +async def test_magentic_kwargs_stored_in_shared_state() -> None: + """Test that kwargs are stored in SharedState when using MagenticWorkflow.run_stream().""" + from agent_framework import MagenticBuilder + from agent_framework._workflows._magentic import ( + MagenticContext, + MagenticManagerBase, + _MagenticProgressLedger, + _MagenticProgressLedgerItem, + ) + + class _MockManager(MagenticManagerBase): + def __init__(self) -> None: + super().__init__(max_stall_count=3, max_reset_count=None, max_round_count=1) + self.task_ledger = None + + async def plan(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Plan", author_name="manager") + + async def replan(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Replan", author_name="manager") + + async def create_progress_ledger(self, context: MagenticContext) -> _MagenticProgressLedger: + return _MagenticProgressLedger( + is_request_satisfied=_MagenticProgressLedgerItem(answer=True, reason="Done"), + is_progress_being_made=_MagenticProgressLedgerItem(answer=True, reason="Progress"), + is_in_loop=_MagenticProgressLedgerItem(answer=False, reason="Not looping"), + instruction_or_question=_MagenticProgressLedgerItem(answer="Done", reason="Done"), + next_speaker=_MagenticProgressLedgerItem(answer="agent1", reason="First"), + ) + + async def prepare_final_answer(self, context: MagenticContext) -> ChatMessage: + return ChatMessage(role=Role.ASSISTANT, text="Final", author_name="manager") + + agent = _KwargsCapturingAgent(name="agent1") + manager = _MockManager() + + magentic_workflow = MagenticBuilder().participants(agent1=agent).with_standard_manager(manager=manager).build() + + # Use MagenticWorkflow.run_stream() which goes through the kwargs attachment path + custom_data = {"magentic_key": "magentic_value"} + + async for event in magentic_workflow.run_stream("test task", custom_data=custom_data): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + break + + # Verify the workflow completed (kwargs were stored, even if agent wasn't invoked) + # The test validates the code path through MagenticWorkflow.run_stream -> _MagenticStartMessage + + +# endregion diff --git a/python/packages/core/tests/workflow/test_workflow_observability.py b/python/packages/core/tests/workflow/test_workflow_observability.py index 1760361f1a..4c97b850b8 100644 --- a/python/packages/core/tests/workflow/test_workflow_observability.py +++ b/python/packages/core/tests/workflow/test_workflow_observability.py @@ -229,8 +229,10 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No assert processing_span.attributes.get("message.payload_type") == "str" -@pytest.mark.parametrize("enable_otel", [False], indirect=True) -async def test_trace_context_disabled_when_tracing_disabled(enable_otel, span_exporter: InMemorySpanExporter) -> None: +@pytest.mark.parametrize("enable_instrumentation", [False], indirect=True) +async def test_trace_context_disabled_when_tracing_disabled( + enable_instrumentation, span_exporter: InMemorySpanExporter +) -> None: """Test that no trace context is added when tracing is disabled.""" # Tracing should be disabled by default executor = MockExecutor("test-executor") @@ -433,7 +435,7 @@ async def handle_message(self, message: str, ctx: WorkflowContext) -> None: assert workflow_span.status.status_code.name == "ERROR" -@pytest.mark.parametrize("enable_otel", [False], indirect=True) +@pytest.mark.parametrize("enable_instrumentation", [False], indirect=True) async def test_message_trace_context_serialization(span_exporter: InMemorySpanExporter) -> None: """Test that message trace context is properly serialized/deserialized.""" ctx = InProcRunnerContext(InMemoryCheckpointStorage()) diff --git a/python/packages/declarative/pyproject.toml b/python/packages/declarative/pyproject.toml index f1032eb064..5a82f3a1e5 100644 --- a/python/packages/declarative/pyproject.toml +++ b/python/packages/declarative/pyproject.toml @@ -4,7 +4,7 @@ description = "Declarative specification support for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" -version = "1.0.0b251209" +version = "1.0.0b251216" license-files = ["LICENSE"] urls.homepage = "https://aka.ms/agent-framework" urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" diff --git a/python/packages/devui/agent_framework_devui/__init__.py b/python/packages/devui/agent_framework_devui/__init__.py index 45d1ea8c2d..9a480d170e 100644 --- a/python/packages/devui/agent_framework_devui/__init__.py +++ b/python/packages/devui/agent_framework_devui/__init__.py @@ -177,9 +177,9 @@ def serve( import os # Only set if not already configured by user - if not os.environ.get("ENABLE_OTEL"): - os.environ["ENABLE_OTEL"] = "true" - logger.info("Set ENABLE_OTEL=true for tracing") + if not os.environ.get("ENABLE_INSTRUMENTATION"): + os.environ["ENABLE_INSTRUMENTATION"] = "true" + logger.info("Set ENABLE_INSTRUMENTATION=true for tracing") if not os.environ.get("ENABLE_SENSITIVE_DATA"): os.environ["ENABLE_SENSITIVE_DATA"] = "true" diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index 813bc4d4cc..be769cba09 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -82,27 +82,23 @@ def _setup_tracing_provider(self) -> None: def _setup_agent_framework_tracing(self) -> None: """Set up Agent Framework's built-in tracing.""" - # Configure Agent Framework tracing only if ENABLE_OTEL is set - if os.environ.get("ENABLE_OTEL"): + # Configure Agent Framework tracing only if ENABLE_INSTRUMENTATION is set + if os.environ.get("ENABLE_INSTRUMENTATION"): try: - from agent_framework.observability import OBSERVABILITY_SETTINGS, setup_observability + from agent_framework.observability import OBSERVABILITY_SETTINGS, configure_otel_providers # Only configure if not already executed if not OBSERVABILITY_SETTINGS._executed_setup: - # Get OTLP endpoint from either custom or standard env var - # This handles the case where env vars are set after ObservabilitySettings was imported - otlp_endpoint = os.environ.get("OTLP_ENDPOINT") or os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") - - # Pass the endpoint explicitly to setup_observability + # Run the configure_otel_providers # This ensures OTLP exporters are created even if env vars were set late - setup_observability(enable_sensitive_data=True, otlp_endpoint=otlp_endpoint) + configure_otel_providers(enable_sensitive_data=True) logger.info("Enabled Agent Framework observability") else: logger.debug("Agent Framework observability already configured") except Exception as e: logger.warning(f"Failed to enable Agent Framework observability: {e}") else: - logger.debug("ENABLE_OTEL not set, skipping observability setup") + logger.debug("ENABLE_INSTRUMENTATION not set, skipping observability setup") async def discover_entities(self) -> list[EntityInfo]: """Discover all available entities. diff --git a/python/packages/devui/agent_framework_devui/_server.py b/python/packages/devui/agent_framework_devui/_server.py index 26630945cb..b3a7c751b6 100644 --- a/python/packages/devui/agent_framework_devui/_server.py +++ b/python/packages/devui/agent_framework_devui/_server.py @@ -407,7 +407,7 @@ async def get_meta() -> MetaResponse: framework="agent_framework", runtime="python", # Python DevUI backend capabilities={ - "tracing": os.getenv("ENABLE_OTEL") == "true", + "tracing": os.getenv("ENABLE_INSTRUMENTATION") == "true", "openai_proxy": openai_executor.is_configured, "deployment": True, # Deployment feature is available }, diff --git a/python/packages/devui/agent_framework_devui/ui/assets/index.js b/python/packages/devui/agent_framework_devui/ui/assets/index.js index d12b71a838..1b05f27842 100644 --- a/python/packages/devui/agent_framework_devui/ui/assets/index.js +++ b/python/packages/devui/agent_framework_devui/ui/assets/index.js @@ -1,4 +1,4 @@ -function yE(e,n){for(var r=0;ra[l]})}}}return Object.freeze(Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}))}(function(){const n=document.createElement("link").relList;if(n&&n.supports&&n.supports("modulepreload"))return;for(const l of document.querySelectorAll('link[rel="modulepreload"]'))a(l);new MutationObserver(l=>{for(const c of l)if(c.type==="childList")for(const d of c.addedNodes)d.tagName==="LINK"&&d.rel==="modulepreload"&&a(d)}).observe(document,{childList:!0,subtree:!0});function r(l){const c={};return l.integrity&&(c.integrity=l.integrity),l.referrerPolicy&&(c.referrerPolicy=l.referrerPolicy),l.crossOrigin==="use-credentials"?c.credentials="include":l.crossOrigin==="anonymous"?c.credentials="omit":c.credentials="same-origin",c}function a(l){if(l.ep)return;l.ep=!0;const c=r(l);fetch(l.href,c)}})();function yp(e){return e&&e.__esModule&&Object.prototype.hasOwnProperty.call(e,"default")?e.default:e}var Gm={exports:{}},Bi={};/** +function yE(e, n) { for (var r = 0; r < n.length; r++) { const a = n[r]; if (typeof a != "string" && !Array.isArray(a)) { for (const l in a) if (l !== "default" && !(l in e)) { const c = Object.getOwnPropertyDescriptor(a, l); c && Object.defineProperty(e, l, c.get ? c : { enumerable: !0, get: () => a[l] }) } } } return Object.freeze(Object.defineProperty(e, Symbol.toStringTag, { value: "Module" })) } (function () { const n = document.createElement("link").relList; if (n && n.supports && n.supports("modulepreload")) return; for (const l of document.querySelectorAll('link[rel="modulepreload"]')) a(l); new MutationObserver(l => { for (const c of l) if (c.type === "childList") for (const d of c.addedNodes) d.tagName === "LINK" && d.rel === "modulepreload" && a(d) }).observe(document, { childList: !0, subtree: !0 }); function r(l) { const c = {}; return l.integrity && (c.integrity = l.integrity), l.referrerPolicy && (c.referrerPolicy = l.referrerPolicy), l.crossOrigin === "use-credentials" ? c.credentials = "include" : l.crossOrigin === "anonymous" ? c.credentials = "omit" : c.credentials = "same-origin", c } function a(l) { if (l.ep) return; l.ep = !0; const c = r(l); fetch(l.href, c) } })(); function yp(e) { return e && e.__esModule && Object.prototype.hasOwnProperty.call(e, "default") ? e.default : e } var Gm = { exports: {} }, Bi = {};/** * @license React * react-jsx-runtime.production.js * @@ -6,7 +6,7 @@ function yE(e,n){for(var r=0;r>>1,C=k[H];if(0>>1;H<$;){var Y=2*(H+1)-1,V=k[Y],W=Y+1,fe=k[W];if(0>l(V,I))Wl(fe,V)?(k[H]=fe,k[W]=I,H=W):(k[H]=V,k[Y]=I,H=Y);else if(Wl(fe,I))k[H]=fe,k[W]=I,H=W;else break e}}return L}function l(k,L){var I=k.sortIndex-L.sortIndex;return I!==0?I:k.id-L.id}if(e.unstable_now=void 0,typeof performance=="object"&&typeof performance.now=="function"){var c=performance;e.unstable_now=function(){return c.now()}}else{var d=Date,f=d.now();e.unstable_now=function(){return d.now()-f}}var m=[],h=[],g=1,y=null,x=3,b=!1,S=!1,N=!1,j=!1,_=typeof setTimeout=="function"?setTimeout:null,M=typeof clearTimeout=="function"?clearTimeout:null,E=typeof setImmediate<"u"?setImmediate:null;function T(k){for(var L=r(h);L!==null;){if(L.callback===null)a(h);else if(L.startTime<=k)a(h),L.sortIndex=L.expirationTime,n(m,L);else break;L=r(h)}}function R(k){if(N=!1,T(k),!S)if(r(m)!==null)S=!0,D||(D=!0,G());else{var L=r(h);L!==null&&U(R,L.startTime-k)}}var D=!1,O=-1,B=5,q=-1;function K(){return j?!0:!(e.unstable_now()-qk&&K());){var H=y.callback;if(typeof H=="function"){y.callback=null,x=y.priorityLevel;var C=H(y.expirationTime<=k);if(k=e.unstable_now(),typeof C=="function"){y.callback=C,T(k),L=!0;break t}y===r(m)&&a(m),T(k)}else a(m);y=r(m)}if(y!==null)L=!0;else{var $=r(h);$!==null&&U(R,$.startTime-k),L=!1}}break e}finally{y=null,x=I,b=!1}L=void 0}}finally{L?G():D=!1}}}var G;if(typeof E=="function")G=function(){E(J)};else if(typeof MessageChannel<"u"){var Z=new MessageChannel,P=Z.port2;Z.port1.onmessage=J,G=function(){P.postMessage(null)}}else G=function(){_(J,0)};function U(k,L){O=_(function(){k(e.unstable_now())},L)}e.unstable_IdlePriority=5,e.unstable_ImmediatePriority=1,e.unstable_LowPriority=4,e.unstable_NormalPriority=3,e.unstable_Profiling=null,e.unstable_UserBlockingPriority=2,e.unstable_cancelCallback=function(k){k.callback=null},e.unstable_forceFrameRate=function(k){0>k||125H?(k.sortIndex=I,n(h,k),r(m)===null&&k===r(h)&&(N?(M(O),O=-1):N=!0,U(R,I-H))):(k.sortIndex=C,n(m,k),S||b||(S=!0,D||(D=!0,G()))),k},e.unstable_shouldYield=K,e.unstable_wrapCallback=function(k){var L=x;return function(){var I=x;x=L;try{return k.apply(this,arguments)}finally{x=I}}}})(Km)),Km}var tv;function SE(){return tv||(tv=1,Wm.exports=NE()),Wm.exports}var Qm={exports:{}},Wt={};/** + */var ev; function NE() { return ev || (ev = 1, (function (e) { function n(k, L) { var I = k.length; k.push(L); e: for (; 0 < I;) { var H = I - 1 >>> 1, C = k[H]; if (0 < l(C, L)) k[H] = L, k[I] = C, I = H; else break e } } function r(k) { return k.length === 0 ? null : k[0] } function a(k) { if (k.length === 0) return null; var L = k[0], I = k.pop(); if (I !== L) { k[0] = I; e: for (var H = 0, C = k.length, $ = C >>> 1; H < $;) { var Y = 2 * (H + 1) - 1, V = k[Y], W = Y + 1, fe = k[W]; if (0 > l(V, I)) W < C && 0 > l(fe, V) ? (k[H] = fe, k[W] = I, H = W) : (k[H] = V, k[Y] = I, H = Y); else if (W < C && 0 > l(fe, I)) k[H] = fe, k[W] = I, H = W; else break e } } return L } function l(k, L) { var I = k.sortIndex - L.sortIndex; return I !== 0 ? I : k.id - L.id } if (e.unstable_now = void 0, typeof performance == "object" && typeof performance.now == "function") { var c = performance; e.unstable_now = function () { return c.now() } } else { var d = Date, f = d.now(); e.unstable_now = function () { return d.now() - f } } var m = [], h = [], g = 1, y = null, x = 3, b = !1, S = !1, N = !1, j = !1, _ = typeof setTimeout == "function" ? setTimeout : null, M = typeof clearTimeout == "function" ? clearTimeout : null, E = typeof setImmediate < "u" ? setImmediate : null; function T(k) { for (var L = r(h); L !== null;) { if (L.callback === null) a(h); else if (L.startTime <= k) a(h), L.sortIndex = L.expirationTime, n(m, L); else break; L = r(h) } } function R(k) { if (N = !1, T(k), !S) if (r(m) !== null) S = !0, D || (D = !0, G()); else { var L = r(h); L !== null && U(R, L.startTime - k) } } var D = !1, O = -1, B = 5, q = -1; function K() { return j ? !0 : !(e.unstable_now() - q < B) } function J() { if (j = !1, D) { var k = e.unstable_now(); q = k; var L = !0; try { e: { S = !1, N && (N = !1, M(O), O = -1), b = !0; var I = x; try { t: { for (T(k), y = r(m); y !== null && !(y.expirationTime > k && K());) { var H = y.callback; if (typeof H == "function") { y.callback = null, x = y.priorityLevel; var C = H(y.expirationTime <= k); if (k = e.unstable_now(), typeof C == "function") { y.callback = C, T(k), L = !0; break t } y === r(m) && a(m), T(k) } else a(m); y = r(m) } if (y !== null) L = !0; else { var $ = r(h); $ !== null && U(R, $.startTime - k), L = !1 } } break e } finally { y = null, x = I, b = !1 } L = void 0 } } finally { L ? G() : D = !1 } } } var G; if (typeof E == "function") G = function () { E(J) }; else if (typeof MessageChannel < "u") { var Z = new MessageChannel, P = Z.port2; Z.port1.onmessage = J, G = function () { P.postMessage(null) } } else G = function () { _(J, 0) }; function U(k, L) { O = _(function () { k(e.unstable_now()) }, L) } e.unstable_IdlePriority = 5, e.unstable_ImmediatePriority = 1, e.unstable_LowPriority = 4, e.unstable_NormalPriority = 3, e.unstable_Profiling = null, e.unstable_UserBlockingPriority = 2, e.unstable_cancelCallback = function (k) { k.callback = null }, e.unstable_forceFrameRate = function (k) { 0 > k || 125 < k ? console.error("forceFrameRate takes a positive int between 0 and 125, forcing frame rates higher than 125 fps is not supported") : B = 0 < k ? Math.floor(1e3 / k) : 5 }, e.unstable_getCurrentPriorityLevel = function () { return x }, e.unstable_next = function (k) { switch (x) { case 1: case 2: case 3: var L = 3; break; default: L = x }var I = x; x = L; try { return k() } finally { x = I } }, e.unstable_requestPaint = function () { j = !0 }, e.unstable_runWithPriority = function (k, L) { switch (k) { case 1: case 2: case 3: case 4: case 5: break; default: k = 3 }var I = x; x = k; try { return L() } finally { x = I } }, e.unstable_scheduleCallback = function (k, L, I) { var H = e.unstable_now(); switch (typeof I == "object" && I !== null ? (I = I.delay, I = typeof I == "number" && 0 < I ? H + I : H) : I = H, k) { case 1: var C = -1; break; case 2: C = 250; break; case 5: C = 1073741823; break; case 4: C = 1e4; break; default: C = 5e3 }return C = I + C, k = { id: g++, callback: L, priorityLevel: k, startTime: I, expirationTime: C, sortIndex: -1 }, I > H ? (k.sortIndex = I, n(h, k), r(m) === null && k === r(h) && (N ? (M(O), O = -1) : N = !0, U(R, I - H))) : (k.sortIndex = C, n(m, k), S || b || (S = !0, D || (D = !0, G()))), k }, e.unstable_shouldYield = K, e.unstable_wrapCallback = function (k) { var L = x; return function () { var I = x; x = L; try { return k.apply(this, arguments) } finally { x = I } } } })(Km)), Km } var tv; function SE() { return tv || (tv = 1, Wm.exports = NE()), Wm.exports } var Qm = { exports: {} }, Wt = {};/** * @license React * react-dom.production.js * @@ -30,7 +30,7 @@ function yE(e,n){for(var r=0;r"u"||typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.checkDCE!="function"))try{__REACT_DEVTOOLS_GLOBAL_HOOK__.checkDCE(e)}catch(n){console.error(n)}}return e(),Qm.exports=jE(),Qm.exports}/** + */var nv; function jE() { if (nv) return Wt; nv = 1; var e = pl(); function n(m) { var h = "https://react.dev/errors/" + m; if (1 < arguments.length) { h += "?args[]=" + encodeURIComponent(arguments[1]); for (var g = 2; g < arguments.length; g++)h += "&args[]=" + encodeURIComponent(arguments[g]) } return "Minified React error #" + m + "; visit " + h + " for the full message or use the non-minified dev environment for full errors and additional helpful warnings." } function r() { } var a = { d: { f: r, r: function () { throw Error(n(522)) }, D: r, C: r, L: r, m: r, X: r, S: r, M: r }, p: 0, findDOMNode: null }, l = Symbol.for("react.portal"); function c(m, h, g) { var y = 3 < arguments.length && arguments[3] !== void 0 ? arguments[3] : null; return { $$typeof: l, key: y == null ? null : "" + y, children: m, containerInfo: h, implementation: g } } var d = e.__CLIENT_INTERNALS_DO_NOT_USE_OR_WARN_USERS_THEY_CANNOT_UPGRADE; function f(m, h) { if (m === "font") return ""; if (typeof h == "string") return h === "use-credentials" ? h : "" } return Wt.__DOM_INTERNALS_DO_NOT_USE_OR_WARN_USERS_THEY_CANNOT_UPGRADE = a, Wt.createPortal = function (m, h) { var g = 2 < arguments.length && arguments[2] !== void 0 ? arguments[2] : null; if (!h || h.nodeType !== 1 && h.nodeType !== 9 && h.nodeType !== 11) throw Error(n(299)); return c(m, h, null, g) }, Wt.flushSync = function (m) { var h = d.T, g = a.p; try { if (d.T = null, a.p = 2, m) return m() } finally { d.T = h, a.p = g, a.d.f() } }, Wt.preconnect = function (m, h) { typeof m == "string" && (h ? (h = h.crossOrigin, h = typeof h == "string" ? h === "use-credentials" ? h : "" : void 0) : h = null, a.d.C(m, h)) }, Wt.prefetchDNS = function (m) { typeof m == "string" && a.d.D(m) }, Wt.preinit = function (m, h) { if (typeof m == "string" && h && typeof h.as == "string") { var g = h.as, y = f(g, h.crossOrigin), x = typeof h.integrity == "string" ? h.integrity : void 0, b = typeof h.fetchPriority == "string" ? h.fetchPriority : void 0; g === "style" ? a.d.S(m, typeof h.precedence == "string" ? h.precedence : void 0, { crossOrigin: y, integrity: x, fetchPriority: b }) : g === "script" && a.d.X(m, { crossOrigin: y, integrity: x, fetchPriority: b, nonce: typeof h.nonce == "string" ? h.nonce : void 0 }) } }, Wt.preinitModule = function (m, h) { if (typeof m == "string") if (typeof h == "object" && h !== null) { if (h.as == null || h.as === "script") { var g = f(h.as, h.crossOrigin); a.d.M(m, { crossOrigin: g, integrity: typeof h.integrity == "string" ? h.integrity : void 0, nonce: typeof h.nonce == "string" ? h.nonce : void 0 }) } } else h == null && a.d.M(m) }, Wt.preload = function (m, h) { if (typeof m == "string" && typeof h == "object" && h !== null && typeof h.as == "string") { var g = h.as, y = f(g, h.crossOrigin); a.d.L(m, g, { crossOrigin: y, integrity: typeof h.integrity == "string" ? h.integrity : void 0, nonce: typeof h.nonce == "string" ? h.nonce : void 0, type: typeof h.type == "string" ? h.type : void 0, fetchPriority: typeof h.fetchPriority == "string" ? h.fetchPriority : void 0, referrerPolicy: typeof h.referrerPolicy == "string" ? h.referrerPolicy : void 0, imageSrcSet: typeof h.imageSrcSet == "string" ? h.imageSrcSet : void 0, imageSizes: typeof h.imageSizes == "string" ? h.imageSizes : void 0, media: typeof h.media == "string" ? h.media : void 0 }) } }, Wt.preloadModule = function (m, h) { if (typeof m == "string") if (h) { var g = f(h.as, h.crossOrigin); a.d.m(m, { as: typeof h.as == "string" && h.as !== "script" ? h.as : void 0, crossOrigin: g, integrity: typeof h.integrity == "string" ? h.integrity : void 0 }) } else a.d.m(m) }, Wt.requestFormReset = function (m) { a.d.r(m) }, Wt.unstable_batchedUpdates = function (m, h) { return m(h) }, Wt.useFormState = function (m, h, g) { return d.H.useFormState(m, h, g) }, Wt.useFormStatus = function () { return d.H.useHostTransitionStatus() }, Wt.version = "19.1.1", Wt } var sv; function ew() { if (sv) return Qm.exports; sv = 1; function e() { if (!(typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ > "u" || typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.checkDCE != "function")) try { __REACT_DEVTOOLS_GLOBAL_HOOK__.checkDCE(e) } catch (n) { console.error(n) } } return e(), Qm.exports = jE(), Qm.exports }/** * @license React * react-dom-client.production.js * @@ -38,414 +38,475 @@ function yE(e,n){for(var r=0;rC||(t.current=H[C],H[C]=null,C--)}function V(t,s){C++,H[C]=t.current,t.current=s}var W=$(null),fe=$(null),ue=$(null),te=$(null);function ie(t,s){switch(V(ue,s),V(fe,t),V(W,null),s.nodeType){case 9:case 11:t=(t=s.documentElement)&&(t=t.namespaceURI)?jy(t):0;break;default:if(t=s.tagName,s=s.namespaceURI)s=jy(s),t=_y(s,t);else switch(t){case"svg":t=1;break;case"math":t=2;break;default:t=0}}Y(W),V(W,t)}function ge(){Y(W),Y(fe),Y(ue)}function be(t){t.memoizedState!==null&&V(te,t);var s=W.current,i=_y(s,t.type);s!==i&&(V(fe,t),V(W,i))}function we(t){fe.current===t&&(Y(W),Y(fe)),te.current===t&&(Y(te),zi._currentValue=I)}var ne=Object.prototype.hasOwnProperty,pe=e.unstable_scheduleCallback,he=e.unstable_cancelCallback,ee=e.unstable_shouldYield,ve=e.unstable_requestPaint,ye=e.unstable_now,Te=e.unstable_getCurrentPriorityLevel,je=e.unstable_ImmediatePriority,$e=e.unstable_UserBlockingPriority,it=e.unstable_NormalPriority,ze=e.unstable_LowPriority,Se=e.unstable_IdlePriority,Pe=e.log,Ee=e.unstable_setDisableYieldValue,He=null,Fe=null;function Nt(t){if(typeof Pe=="function"&&Ee(t),Fe&&typeof Fe.setStrictMode=="function")try{Fe.setStrictMode(He,t)}catch{}}var yt=Math.clz32?Math.clz32:xe,hs=Math.log,wo=Math.LN2;function xe(t){return t>>>=0,t===0?32:31-(hs(t)/wo|0)|0}var Re=256,Ue=4194304;function Et(t){var s=t&42;if(s!==0)return s;switch(t&-t){case 1:return 1;case 2:return 2;case 4:return 4;case 8:return 8;case 16:return 16;case 32:return 32;case 64:return 64;case 128:return 128;case 256:case 512:case 1024:case 2048:case 4096:case 8192:case 16384:case 32768:case 65536:case 131072:case 262144:case 524288:case 1048576:case 2097152:return t&4194048;case 4194304:case 8388608:case 16777216:case 33554432:return t&62914560;case 67108864:return 67108864;case 134217728:return 134217728;case 268435456:return 268435456;case 536870912:return 536870912;case 1073741824:return 0;default:return t}}function Dn(t,s,i){var u=t.pendingLanes;if(u===0)return 0;var p=0,v=t.suspendedLanes,A=t.pingedLanes;t=t.warmLanes;var z=u&134217727;return z!==0?(u=z&~v,u!==0?p=Et(u):(A&=z,A!==0?p=Et(A):i||(i=z&~t,i!==0&&(p=Et(i))))):(z=u&~v,z!==0?p=Et(z):A!==0?p=Et(A):i||(i=u&~t,i!==0&&(p=Et(i)))),p===0?0:s!==0&&s!==p&&(s&v)===0&&(v=p&-p,i=s&-s,v>=i||v===32&&(i&4194048)!==0)?s:p}function Le(t,s){return(t.pendingLanes&~(t.suspendedLanes&~t.pingedLanes)&s)===0}function Ne(t,s){switch(t){case 1:case 2:case 4:case 8:case 64:return s+250;case 16:case 32:case 128:case 256:case 512:case 1024:case 2048:case 4096:case 8192:case 16384:case 32768:case 65536:case 131072:case 262144:case 524288:case 1048576:case 2097152:return s+5e3;case 4194304:case 8388608:case 16777216:case 33554432:return-1;case 67108864:case 134217728:case 268435456:case 536870912:case 1073741824:return-1;default:return-1}}function lt(){var t=Re;return Re<<=1,(Re&4194048)===0&&(Re=256),t}function ot(){var t=Ue;return Ue<<=1,(Ue&62914560)===0&&(Ue=4194304),t}function At(t){for(var s=[],i=0;31>i;i++)s.push(t);return s}function en(t,s){t.pendingLanes|=s,s!==268435456&&(t.suspendedLanes=0,t.pingedLanes=0,t.warmLanes=0)}function On(t,s,i,u,p,v){var A=t.pendingLanes;t.pendingLanes=i,t.suspendedLanes=0,t.pingedLanes=0,t.warmLanes=0,t.expiredLanes&=i,t.entangledLanes&=i,t.errorRecoveryDisabledLanes&=i,t.shellSuspendCounter=0;var z=t.entanglements,F=t.expirationTimes,re=t.hiddenUpdates;for(i=A&~i;0)":-1p||F[u]!==re[p]){var le=` -`+F[u].replace(" at new "," at ");return t.displayName&&le.includes("")&&(le=le.replace("",t.displayName)),le}while(1<=u&&0<=p);break}}}finally{qa=!1,Error.prepareStackTrace=i}return(i=t?t.displayName||t.name:"")?ys(i):""}function Ud(t){switch(t.tag){case 26:case 27:case 5:return ys(t.type);case 16:return ys("Lazy");case 13:return ys("Suspense");case 19:return ys("SuspenseList");case 0:case 15:return Fa(t.type,!1);case 11:return Fa(t.type.render,!1);case 1:return Fa(t.type,!0);case 31:return ys("Activity");default:return""}}function Il(t){try{var s="";do s+=Ud(t),t=t.return;while(t);return s}catch(i){return` -Error generating stack: `+i.message+` -`+i.stack}}function tn(t){switch(typeof t){case"bigint":case"boolean":case"number":case"string":case"undefined":return t;case"object":return t;default:return""}}function Ll(t){var s=t.type;return(t=t.nodeName)&&t.toLowerCase()==="input"&&(s==="checkbox"||s==="radio")}function Vd(t){var s=Ll(t)?"checked":"value",i=Object.getOwnPropertyDescriptor(t.constructor.prototype,s),u=""+t[s];if(!t.hasOwnProperty(s)&&typeof i<"u"&&typeof i.get=="function"&&typeof i.set=="function"){var p=i.get,v=i.set;return Object.defineProperty(t,s,{configurable:!0,get:function(){return p.call(this)},set:function(A){u=""+A,v.call(this,A)}}),Object.defineProperty(t,s,{enumerable:i.enumerable}),{getValue:function(){return u},setValue:function(A){u=""+A},stopTracking:function(){t._valueTracker=null,delete t[s]}}}}function jo(t){t._valueTracker||(t._valueTracker=Vd(t))}function Ya(t){if(!t)return!1;var s=t._valueTracker;if(!s)return!0;var i=s.getValue(),u="";return t&&(u=Ll(t)?t.checked?"true":"false":t.value),t=u,t!==i?(s.setValue(t),!0):!1}function _o(t){if(t=t||(typeof document<"u"?document:void 0),typeof t>"u")return null;try{return t.activeElement||t.body}catch{return t.body}}var qd=/[\n"\\]/g;function nn(t){return t.replace(qd,function(s){return"\\"+s.charCodeAt(0).toString(16)+" "})}function zr(t,s,i,u,p,v,A,z){t.name="",A!=null&&typeof A!="function"&&typeof A!="symbol"&&typeof A!="boolean"?t.type=A:t.removeAttribute("type"),s!=null?A==="number"?(s===0&&t.value===""||t.value!=s)&&(t.value=""+tn(s)):t.value!==""+tn(s)&&(t.value=""+tn(s)):A!=="submit"&&A!=="reset"||t.removeAttribute("value"),s!=null?Ga(t,A,tn(s)):i!=null?Ga(t,A,tn(i)):u!=null&&t.removeAttribute("value"),p==null&&v!=null&&(t.defaultChecked=!!v),p!=null&&(t.checked=p&&typeof p!="function"&&typeof p!="symbol"),z!=null&&typeof z!="function"&&typeof z!="symbol"&&typeof z!="boolean"?t.name=""+tn(z):t.removeAttribute("name")}function Hl(t,s,i,u,p,v,A,z){if(v!=null&&typeof v!="function"&&typeof v!="symbol"&&typeof v!="boolean"&&(t.type=v),s!=null||i!=null){if(!(v!=="submit"&&v!=="reset"||s!=null))return;i=i!=null?""+tn(i):"",s=s!=null?""+tn(s):i,z||s===t.value||(t.value=s),t.defaultValue=s}u=u??p,u=typeof u!="function"&&typeof u!="symbol"&&!!u,t.checked=z?t.checked:!!u,t.defaultChecked=!!u,A!=null&&typeof A!="function"&&typeof A!="symbol"&&typeof A!="boolean"&&(t.name=A)}function Ga(t,s,i){s==="number"&&_o(t.ownerDocument)===t||t.defaultValue===""+i||(t.defaultValue=""+i)}function vs(t,s,i,u){if(t=t.options,s){s={};for(var p=0;p"u"||typeof window.document>"u"||typeof window.document.createElement>"u"),Zd=!1;if(bs)try{var Za={};Object.defineProperty(Za,"passive",{get:function(){Zd=!0}}),window.addEventListener("test",Za,Za),window.removeEventListener("test",Za,Za)}catch{Zd=!1}var Zs=null,Wd=null,Bl=null;function kg(){if(Bl)return Bl;var t,s=Wd,i=s.length,u,p="value"in Zs?Zs.value:Zs.textContent,v=p.length;for(t=0;t=Qa),Og=" ",zg=!1;function Ig(t,s){switch(t){case"keyup":return Uj.indexOf(s.keyCode)!==-1;case"keydown":return s.keyCode!==229;case"keypress":case"mousedown":case"focusout":return!0;default:return!1}}function Lg(t){return t=t.detail,typeof t=="object"&&"data"in t?t.data:null}var Ao=!1;function qj(t,s){switch(t){case"compositionend":return Lg(s);case"keypress":return s.which!==32?null:(zg=!0,Og);case"textInput":return t=s.data,t===Og&&zg?null:t;default:return null}}function Fj(t,s){if(Ao)return t==="compositionend"||!tf&&Ig(t,s)?(t=kg(),Bl=Wd=Zs=null,Ao=!1,t):null;switch(t){case"paste":return null;case"keypress":if(!(s.ctrlKey||s.altKey||s.metaKey)||s.ctrlKey&&s.altKey){if(s.char&&1=s)return{node:i,offset:s-t};t=u}e:{for(;i;){if(i.nextSibling){i=i.nextSibling;break e}i=i.parentNode}i=void 0}i=Fg(i)}}function Gg(t,s){return t&&s?t===s?!0:t&&t.nodeType===3?!1:s&&s.nodeType===3?Gg(t,s.parentNode):"contains"in t?t.contains(s):t.compareDocumentPosition?!!(t.compareDocumentPosition(s)&16):!1:!1}function Xg(t){t=t!=null&&t.ownerDocument!=null&&t.ownerDocument.defaultView!=null?t.ownerDocument.defaultView:window;for(var s=_o(t.document);s instanceof t.HTMLIFrameElement;){try{var i=typeof s.contentWindow.location.href=="string"}catch{i=!1}if(i)t=s.contentWindow;else break;s=_o(t.document)}return s}function rf(t){var s=t&&t.nodeName&&t.nodeName.toLowerCase();return s&&(s==="input"&&(t.type==="text"||t.type==="search"||t.type==="tel"||t.type==="url"||t.type==="password")||s==="textarea"||t.contentEditable==="true")}var Jj=bs&&"documentMode"in document&&11>=document.documentMode,Mo=null,of=null,ni=null,af=!1;function Zg(t,s,i){var u=i.window===i?i.document:i.nodeType===9?i:i.ownerDocument;af||Mo==null||Mo!==_o(u)||(u=Mo,"selectionStart"in u&&rf(u)?u={start:u.selectionStart,end:u.selectionEnd}:(u=(u.ownerDocument&&u.ownerDocument.defaultView||window).getSelection(),u={anchorNode:u.anchorNode,anchorOffset:u.anchorOffset,focusNode:u.focusNode,focusOffset:u.focusOffset}),ni&&ti(ni,u)||(ni=u,u=Mc(of,"onSelect"),0>=A,p-=A,Ns=1<<32-yt(s)+p|i<v?v:8;var A=k.T,z={};k.T=z,Yf(t,!1,s,i);try{var F=p(),re=k.S;if(re!==null&&re(z,F),F!==null&&typeof F=="object"&&typeof F.then=="function"){var le=l_(F,u);xi(t,s,le,hn(t))}else xi(t,s,u,hn(t))}catch(me){xi(t,s,{then:function(){},status:"rejected",reason:me},hn())}finally{L.p=v,k.T=A}}function m_(){}function qf(t,s,i,u){if(t.tag!==5)throw Error(a(476));var p=Wx(t).queue;Zx(t,p,s,I,i===null?m_:function(){return Kx(t),i(u)})}function Wx(t){var s=t.memoizedState;if(s!==null)return s;s={memoizedState:I,baseState:I,baseQueue:null,queue:{pending:null,lanes:0,dispatch:null,lastRenderedReducer:Es,lastRenderedState:I},next:null};var i={};return s.next={memoizedState:i,baseState:i,baseQueue:null,queue:{pending:null,lanes:0,dispatch:null,lastRenderedReducer:Es,lastRenderedState:i},next:null},t.memoizedState=s,t=t.alternate,t!==null&&(t.memoizedState=s),s}function Kx(t){var s=Wx(t).next.queue;xi(t,s,{},hn())}function Ff(){return Zt(zi)}function Qx(){return Rt().memoizedState}function Jx(){return Rt().memoizedState}function h_(t){for(var s=t.return;s!==null;){switch(s.tag){case 24:case 3:var i=hn();t=Qs(i);var u=Js(s,t,i);u!==null&&(pn(u,s,i),di(u,s,i)),s={cache:wf()},t.payload=s;return}s=s.return}}function p_(t,s,i){var u=hn();i={lane:u,revertLane:0,action:i,hasEagerState:!1,eagerState:null,next:null},uc(t)?t0(s,i):(i=df(t,s,i,u),i!==null&&(pn(i,t,u),n0(i,s,u)))}function e0(t,s,i){var u=hn();xi(t,s,i,u)}function xi(t,s,i,u){var p={lane:u,revertLane:0,action:i,hasEagerState:!1,eagerState:null,next:null};if(uc(t))t0(s,p);else{var v=t.alternate;if(t.lanes===0&&(v===null||v.lanes===0)&&(v=s.lastRenderedReducer,v!==null))try{var A=s.lastRenderedState,z=v(A,i);if(p.hasEagerState=!0,p.eagerState=z,cn(z,A))return Gl(t,s,p,0),gt===null&&Yl(),!1}catch{}finally{}if(i=df(t,s,p,u),i!==null)return pn(i,t,u),n0(i,s,u),!0}return!1}function Yf(t,s,i,u){if(u={lane:2,revertLane:jm(),action:u,hasEagerState:!1,eagerState:null,next:null},uc(t)){if(s)throw Error(a(479))}else s=df(t,i,u,2),s!==null&&pn(s,t,2)}function uc(t){var s=t.alternate;return t===Ze||s!==null&&s===Ze}function t0(t,s){Bo=rc=!0;var i=t.pending;i===null?s.next=s:(s.next=i.next,i.next=s),t.pending=s}function n0(t,s,i){if((i&4194048)!==0){var u=s.lanes;u&=t.pendingLanes,i|=u,s.lanes=i,La(t,i)}}var dc={readContext:Zt,use:ac,useCallback:Ct,useContext:Ct,useEffect:Ct,useImperativeHandle:Ct,useLayoutEffect:Ct,useInsertionEffect:Ct,useMemo:Ct,useReducer:Ct,useRef:Ct,useState:Ct,useDebugValue:Ct,useDeferredValue:Ct,useTransition:Ct,useSyncExternalStore:Ct,useId:Ct,useHostTransitionStatus:Ct,useFormState:Ct,useActionState:Ct,useOptimistic:Ct,useMemoCache:Ct,useCacheRefresh:Ct},s0={readContext:Zt,use:ac,useCallback:function(t,s){return rn().memoizedState=[t,s===void 0?null:s],t},useContext:Zt,useEffect:Bx,useImperativeHandle:function(t,s,i){i=i!=null?i.concat([t]):null,cc(4194308,4,qx.bind(null,s,t),i)},useLayoutEffect:function(t,s){return cc(4194308,4,t,s)},useInsertionEffect:function(t,s){cc(4,2,t,s)},useMemo:function(t,s){var i=rn();s=s===void 0?null:s;var u=t();if(Gr){Nt(!0);try{t()}finally{Nt(!1)}}return i.memoizedState=[u,s],u},useReducer:function(t,s,i){var u=rn();if(i!==void 0){var p=i(s);if(Gr){Nt(!0);try{i(s)}finally{Nt(!1)}}}else p=s;return u.memoizedState=u.baseState=p,t={pending:null,lanes:0,dispatch:null,lastRenderedReducer:t,lastRenderedState:p},u.queue=t,t=t.dispatch=p_.bind(null,Ze,t),[u.memoizedState,t]},useRef:function(t){var s=rn();return t={current:t},s.memoizedState=t},useState:function(t){t=Bf(t);var s=t.queue,i=e0.bind(null,Ze,s);return s.dispatch=i,[t.memoizedState,i]},useDebugValue:Uf,useDeferredValue:function(t,s){var i=rn();return Vf(i,t,s)},useTransition:function(){var t=Bf(!1);return t=Zx.bind(null,Ze,t.queue,!0,!1),rn().memoizedState=t,[!1,t]},useSyncExternalStore:function(t,s,i){var u=Ze,p=rn();if(at){if(i===void 0)throw Error(a(407));i=i()}else{if(i=s(),gt===null)throw Error(a(349));(nt&124)!==0||jx(u,s,i)}p.memoizedState=i;var v={value:i,getSnapshot:s};return p.queue=v,Bx(Ex.bind(null,u,v,t),[t]),u.flags|=2048,Uo(9,lc(),_x.bind(null,u,v,i,s),null),i},useId:function(){var t=rn(),s=gt.identifierPrefix;if(at){var i=Ss,u=Ns;i=(u&~(1<<32-yt(u)-1)).toString(32)+i,s="«"+s+"R"+i,i=oc++,0Ve?(Pt=Oe,Oe=null):Pt=Oe.sibling;var rt=oe(Q,Oe,se[Ve],ce);if(rt===null){Oe===null&&(Oe=Pt);break}t&&Oe&&rt.alternate===null&&s(Q,Oe),X=v(rt,X,Ve),Ke===null?Ce=rt:Ke.sibling=rt,Ke=rt,Oe=Pt}if(Ve===se.length)return i(Q,Oe),at&&Pr(Q,Ve),Ce;if(Oe===null){for(;VeVe?(Pt=Oe,Oe=null):Pt=Oe.sibling;var gr=oe(Q,Oe,rt.value,ce);if(gr===null){Oe===null&&(Oe=Pt);break}t&&Oe&&gr.alternate===null&&s(Q,Oe),X=v(gr,X,Ve),Ke===null?Ce=gr:Ke.sibling=gr,Ke=gr,Oe=Pt}if(rt.done)return i(Q,Oe),at&&Pr(Q,Ve),Ce;if(Oe===null){for(;!rt.done;Ve++,rt=se.next())rt=me(Q,rt.value,ce),rt!==null&&(X=v(rt,X,Ve),Ke===null?Ce=rt:Ke.sibling=rt,Ke=rt);return at&&Pr(Q,Ve),Ce}for(Oe=u(Oe);!rt.done;Ve++,rt=se.next())rt=ae(Oe,Q,Ve,rt.value,ce),rt!==null&&(t&&rt.alternate!==null&&Oe.delete(rt.key===null?Ve:rt.key),X=v(rt,X,Ve),Ke===null?Ce=rt:Ke.sibling=rt,Ke=rt);return t&&Oe.forEach(function(xE){return s(Q,xE)}),at&&Pr(Q,Ve),Ce}function mt(Q,X,se,ce){if(typeof se=="object"&&se!==null&&se.type===S&&se.key===null&&(se=se.props.children),typeof se=="object"&&se!==null){switch(se.$$typeof){case x:e:{for(var Ce=se.key;X!==null;){if(X.key===Ce){if(Ce=se.type,Ce===S){if(X.tag===7){i(Q,X.sibling),ce=p(X,se.props.children),ce.return=Q,Q=ce;break e}}else if(X.elementType===Ce||typeof Ce=="object"&&Ce!==null&&Ce.$$typeof===B&&o0(Ce)===X.type){i(Q,X.sibling),ce=p(X,se.props),vi(ce,se),ce.return=Q,Q=ce;break e}i(Q,X);break}else s(Q,X);X=X.sibling}se.type===S?(ce=$r(se.props.children,Q.mode,ce,se.key),ce.return=Q,Q=ce):(ce=Zl(se.type,se.key,se.props,null,Q.mode,ce),vi(ce,se),ce.return=Q,Q=ce)}return A(Q);case b:e:{for(Ce=se.key;X!==null;){if(X.key===Ce)if(X.tag===4&&X.stateNode.containerInfo===se.containerInfo&&X.stateNode.implementation===se.implementation){i(Q,X.sibling),ce=p(X,se.children||[]),ce.return=Q,Q=ce;break e}else{i(Q,X);break}else s(Q,X);X=X.sibling}ce=hf(se,Q.mode,ce),ce.return=Q,Q=ce}return A(Q);case B:return Ce=se._init,se=Ce(se._payload),mt(Q,X,se,ce)}if(U(se))return qe(Q,X,se,ce);if(G(se)){if(Ce=G(se),typeof Ce!="function")throw Error(a(150));return se=Ce.call(se),Be(Q,X,se,ce)}if(typeof se.then=="function")return mt(Q,X,fc(se),ce);if(se.$$typeof===E)return mt(Q,X,Jl(Q,se),ce);mc(Q,se)}return typeof se=="string"&&se!==""||typeof se=="number"||typeof se=="bigint"?(se=""+se,X!==null&&X.tag===6?(i(Q,X.sibling),ce=p(X,se),ce.return=Q,Q=ce):(i(Q,X),ce=mf(se,Q.mode,ce),ce.return=Q,Q=ce),A(Q)):i(Q,X)}return function(Q,X,se,ce){try{yi=0;var Ce=mt(Q,X,se,ce);return Vo=null,Ce}catch(Oe){if(Oe===ci||Oe===tc)throw Oe;var Ke=un(29,Oe,null,Q.mode);return Ke.lanes=ce,Ke.return=Q,Ke}finally{}}}var qo=a0(!0),i0=a0(!1),En=$(null),Wn=null;function tr(t){var s=t.alternate;V(zt,zt.current&1),V(En,t),Wn===null&&(s===null||$o.current!==null||s.memoizedState!==null)&&(Wn=t)}function l0(t){if(t.tag===22){if(V(zt,zt.current),V(En,t),Wn===null){var s=t.alternate;s!==null&&s.memoizedState!==null&&(Wn=t)}}else nr()}function nr(){V(zt,zt.current),V(En,En.current)}function Cs(t){Y(En),Wn===t&&(Wn=null),Y(zt)}var zt=$(0);function hc(t){for(var s=t;s!==null;){if(s.tag===13){var i=s.memoizedState;if(i!==null&&(i=i.dehydrated,i===null||i.data==="$?"||Im(i)))return s}else if(s.tag===19&&s.memoizedProps.revealOrder!==void 0){if((s.flags&128)!==0)return s}else if(s.child!==null){s.child.return=s,s=s.child;continue}if(s===t)break;for(;s.sibling===null;){if(s.return===null||s.return===t)return null;s=s.return}s.sibling.return=s.return,s=s.sibling}return null}function Gf(t,s,i,u){s=t.memoizedState,i=i(u,s),i=i==null?s:g({},s,i),t.memoizedState=i,t.lanes===0&&(t.updateQueue.baseState=i)}var Xf={enqueueSetState:function(t,s,i){t=t._reactInternals;var u=hn(),p=Qs(u);p.payload=s,i!=null&&(p.callback=i),s=Js(t,p,u),s!==null&&(pn(s,t,u),di(s,t,u))},enqueueReplaceState:function(t,s,i){t=t._reactInternals;var u=hn(),p=Qs(u);p.tag=1,p.payload=s,i!=null&&(p.callback=i),s=Js(t,p,u),s!==null&&(pn(s,t,u),di(s,t,u))},enqueueForceUpdate:function(t,s){t=t._reactInternals;var i=hn(),u=Qs(i);u.tag=2,s!=null&&(u.callback=s),s=Js(t,u,i),s!==null&&(pn(s,t,i),di(s,t,i))}};function c0(t,s,i,u,p,v,A){return t=t.stateNode,typeof t.shouldComponentUpdate=="function"?t.shouldComponentUpdate(u,v,A):s.prototype&&s.prototype.isPureReactComponent?!ti(i,u)||!ti(p,v):!0}function u0(t,s,i,u){t=s.state,typeof s.componentWillReceiveProps=="function"&&s.componentWillReceiveProps(i,u),typeof s.UNSAFE_componentWillReceiveProps=="function"&&s.UNSAFE_componentWillReceiveProps(i,u),s.state!==t&&Xf.enqueueReplaceState(s,s.state,null)}function Xr(t,s){var i=s;if("ref"in s){i={};for(var u in s)u!=="ref"&&(i[u]=s[u])}if(t=t.defaultProps){i===s&&(i=g({},i));for(var p in t)i[p]===void 0&&(i[p]=t[p])}return i}var pc=typeof reportError=="function"?reportError:function(t){if(typeof window=="object"&&typeof window.ErrorEvent=="function"){var s=new window.ErrorEvent("error",{bubbles:!0,cancelable:!0,message:typeof t=="object"&&t!==null&&typeof t.message=="string"?String(t.message):String(t),error:t});if(!window.dispatchEvent(s))return}else if(typeof process=="object"&&typeof process.emit=="function"){process.emit("uncaughtException",t);return}console.error(t)};function d0(t){pc(t)}function f0(t){console.error(t)}function m0(t){pc(t)}function gc(t,s){try{var i=t.onUncaughtError;i(s.value,{componentStack:s.stack})}catch(u){setTimeout(function(){throw u})}}function h0(t,s,i){try{var u=t.onCaughtError;u(i.value,{componentStack:i.stack,errorBoundary:s.tag===1?s.stateNode:null})}catch(p){setTimeout(function(){throw p})}}function Zf(t,s,i){return i=Qs(i),i.tag=3,i.payload={element:null},i.callback=function(){gc(t,s)},i}function p0(t){return t=Qs(t),t.tag=3,t}function g0(t,s,i,u){var p=i.type.getDerivedStateFromError;if(typeof p=="function"){var v=u.value;t.payload=function(){return p(v)},t.callback=function(){h0(s,i,u)}}var A=i.stateNode;A!==null&&typeof A.componentDidCatch=="function"&&(t.callback=function(){h0(s,i,u),typeof p!="function"&&(lr===null?lr=new Set([this]):lr.add(this));var z=u.stack;this.componentDidCatch(u.value,{componentStack:z!==null?z:""})})}function x_(t,s,i,u,p){if(i.flags|=32768,u!==null&&typeof u=="object"&&typeof u.then=="function"){if(s=i.alternate,s!==null&&ai(s,i,p,!0),i=En.current,i!==null){switch(i.tag){case 13:return Wn===null?vm():i.alternate===null&&_t===0&&(_t=3),i.flags&=-257,i.flags|=65536,i.lanes=p,u===jf?i.flags|=16384:(s=i.updateQueue,s===null?i.updateQueue=new Set([u]):s.add(u),wm(t,u,p)),!1;case 22:return i.flags|=65536,u===jf?i.flags|=16384:(s=i.updateQueue,s===null?(s={transitions:null,markerInstances:null,retryQueue:new Set([u])},i.updateQueue=s):(i=s.retryQueue,i===null?s.retryQueue=new Set([u]):i.add(u)),wm(t,u,p)),!1}throw Error(a(435,i.tag))}return wm(t,u,p),vm(),!1}if(at)return s=En.current,s!==null?((s.flags&65536)===0&&(s.flags|=256),s.flags|=65536,s.lanes=p,u!==xf&&(t=Error(a(422),{cause:u}),oi(Nn(t,i)))):(u!==xf&&(s=Error(a(423),{cause:u}),oi(Nn(s,i))),t=t.current.alternate,t.flags|=65536,p&=-p,t.lanes|=p,u=Nn(u,i),p=Zf(t.stateNode,u,p),Cf(t,p),_t!==4&&(_t=2)),!1;var v=Error(a(520),{cause:u});if(v=Nn(v,i),Ei===null?Ei=[v]:Ei.push(v),_t!==4&&(_t=2),s===null)return!0;u=Nn(u,i),i=s;do{switch(i.tag){case 3:return i.flags|=65536,t=p&-p,i.lanes|=t,t=Zf(i.stateNode,u,t),Cf(i,t),!1;case 1:if(s=i.type,v=i.stateNode,(i.flags&128)===0&&(typeof s.getDerivedStateFromError=="function"||v!==null&&typeof v.componentDidCatch=="function"&&(lr===null||!lr.has(v))))return i.flags|=65536,p&=-p,i.lanes|=p,p=p0(p),g0(p,t,i,u),Cf(i,p),!1}i=i.return}while(i!==null);return!1}var x0=Error(a(461)),$t=!1;function Vt(t,s,i,u){s.child=t===null?i0(s,null,i,u):qo(s,t.child,i,u)}function y0(t,s,i,u,p){i=i.render;var v=s.ref;if("ref"in u){var A={};for(var z in u)z!=="ref"&&(A[z]=u[z])}else A=u;return Fr(s),u=Rf(t,s,i,A,v,p),z=Df(),t!==null&&!$t?(Of(t,s,p),ks(t,s,p)):(at&&z&&pf(s),s.flags|=1,Vt(t,s,u,p),s.child)}function v0(t,s,i,u,p){if(t===null){var v=i.type;return typeof v=="function"&&!ff(v)&&v.defaultProps===void 0&&i.compare===null?(s.tag=15,s.type=v,b0(t,s,v,u,p)):(t=Zl(i.type,null,u,s,s.mode,p),t.ref=s.ref,t.return=s,s.child=t)}if(v=t.child,!sm(t,p)){var A=v.memoizedProps;if(i=i.compare,i=i!==null?i:ti,i(A,u)&&t.ref===s.ref)return ks(t,s,p)}return s.flags|=1,t=ws(v,u),t.ref=s.ref,t.return=s,s.child=t}function b0(t,s,i,u,p){if(t!==null){var v=t.memoizedProps;if(ti(v,u)&&t.ref===s.ref)if($t=!1,s.pendingProps=u=v,sm(t,p))(t.flags&131072)!==0&&($t=!0);else return s.lanes=t.lanes,ks(t,s,p)}return Wf(t,s,i,u,p)}function w0(t,s,i){var u=s.pendingProps,p=u.children,v=t!==null?t.memoizedState:null;if(u.mode==="hidden"){if((s.flags&128)!==0){if(u=v!==null?v.baseLanes|i:i,t!==null){for(p=s.child=t.child,v=0;p!==null;)v=v|p.lanes|p.childLanes,p=p.sibling;s.childLanes=v&~u}else s.childLanes=0,s.child=null;return N0(t,s,u,i)}if((i&536870912)!==0)s.memoizedState={baseLanes:0,cachePool:null},t!==null&&ec(s,v!==null?v.cachePool:null),v!==null?bx(s,v):Af(),l0(s);else return s.lanes=s.childLanes=536870912,N0(t,s,v!==null?v.baseLanes|i:i,i)}else v!==null?(ec(s,v.cachePool),bx(s,v),nr(),s.memoizedState=null):(t!==null&&ec(s,null),Af(),nr());return Vt(t,s,p,i),s.child}function N0(t,s,i,u){var p=Sf();return p=p===null?null:{parent:Ot._currentValue,pool:p},s.memoizedState={baseLanes:i,cachePool:p},t!==null&&ec(s,null),Af(),l0(s),t!==null&&ai(t,s,u,!0),null}function xc(t,s){var i=s.ref;if(i===null)t!==null&&t.ref!==null&&(s.flags|=4194816);else{if(typeof i!="function"&&typeof i!="object")throw Error(a(284));(t===null||t.ref!==i)&&(s.flags|=4194816)}}function Wf(t,s,i,u,p){return Fr(s),i=Rf(t,s,i,u,void 0,p),u=Df(),t!==null&&!$t?(Of(t,s,p),ks(t,s,p)):(at&&u&&pf(s),s.flags|=1,Vt(t,s,i,p),s.child)}function S0(t,s,i,u,p,v){return Fr(s),s.updateQueue=null,i=Nx(s,u,i,p),wx(t),u=Df(),t!==null&&!$t?(Of(t,s,v),ks(t,s,v)):(at&&u&&pf(s),s.flags|=1,Vt(t,s,i,v),s.child)}function j0(t,s,i,u,p){if(Fr(s),s.stateNode===null){var v=Oo,A=i.contextType;typeof A=="object"&&A!==null&&(v=Zt(A)),v=new i(u,v),s.memoizedState=v.state!==null&&v.state!==void 0?v.state:null,v.updater=Xf,s.stateNode=v,v._reactInternals=s,v=s.stateNode,v.props=u,v.state=s.memoizedState,v.refs={},_f(s),A=i.contextType,v.context=typeof A=="object"&&A!==null?Zt(A):Oo,v.state=s.memoizedState,A=i.getDerivedStateFromProps,typeof A=="function"&&(Gf(s,i,A,u),v.state=s.memoizedState),typeof i.getDerivedStateFromProps=="function"||typeof v.getSnapshotBeforeUpdate=="function"||typeof v.UNSAFE_componentWillMount!="function"&&typeof v.componentWillMount!="function"||(A=v.state,typeof v.componentWillMount=="function"&&v.componentWillMount(),typeof v.UNSAFE_componentWillMount=="function"&&v.UNSAFE_componentWillMount(),A!==v.state&&Xf.enqueueReplaceState(v,v.state,null),mi(s,u,v,p),fi(),v.state=s.memoizedState),typeof v.componentDidMount=="function"&&(s.flags|=4194308),u=!0}else if(t===null){v=s.stateNode;var z=s.memoizedProps,F=Xr(i,z);v.props=F;var re=v.context,le=i.contextType;A=Oo,typeof le=="object"&&le!==null&&(A=Zt(le));var me=i.getDerivedStateFromProps;le=typeof me=="function"||typeof v.getSnapshotBeforeUpdate=="function",z=s.pendingProps!==z,le||typeof v.UNSAFE_componentWillReceiveProps!="function"&&typeof v.componentWillReceiveProps!="function"||(z||re!==A)&&u0(s,v,u,A),Ks=!1;var oe=s.memoizedState;v.state=oe,mi(s,u,v,p),fi(),re=s.memoizedState,z||oe!==re||Ks?(typeof me=="function"&&(Gf(s,i,me,u),re=s.memoizedState),(F=Ks||c0(s,i,F,u,oe,re,A))?(le||typeof v.UNSAFE_componentWillMount!="function"&&typeof v.componentWillMount!="function"||(typeof v.componentWillMount=="function"&&v.componentWillMount(),typeof v.UNSAFE_componentWillMount=="function"&&v.UNSAFE_componentWillMount()),typeof v.componentDidMount=="function"&&(s.flags|=4194308)):(typeof v.componentDidMount=="function"&&(s.flags|=4194308),s.memoizedProps=u,s.memoizedState=re),v.props=u,v.state=re,v.context=A,u=F):(typeof v.componentDidMount=="function"&&(s.flags|=4194308),u=!1)}else{v=s.stateNode,Ef(t,s),A=s.memoizedProps,le=Xr(i,A),v.props=le,me=s.pendingProps,oe=v.context,re=i.contextType,F=Oo,typeof re=="object"&&re!==null&&(F=Zt(re)),z=i.getDerivedStateFromProps,(re=typeof z=="function"||typeof v.getSnapshotBeforeUpdate=="function")||typeof v.UNSAFE_componentWillReceiveProps!="function"&&typeof v.componentWillReceiveProps!="function"||(A!==me||oe!==F)&&u0(s,v,u,F),Ks=!1,oe=s.memoizedState,v.state=oe,mi(s,u,v,p),fi();var ae=s.memoizedState;A!==me||oe!==ae||Ks||t!==null&&t.dependencies!==null&&Ql(t.dependencies)?(typeof z=="function"&&(Gf(s,i,z,u),ae=s.memoizedState),(le=Ks||c0(s,i,le,u,oe,ae,F)||t!==null&&t.dependencies!==null&&Ql(t.dependencies))?(re||typeof v.UNSAFE_componentWillUpdate!="function"&&typeof v.componentWillUpdate!="function"||(typeof v.componentWillUpdate=="function"&&v.componentWillUpdate(u,ae,F),typeof v.UNSAFE_componentWillUpdate=="function"&&v.UNSAFE_componentWillUpdate(u,ae,F)),typeof v.componentDidUpdate=="function"&&(s.flags|=4),typeof v.getSnapshotBeforeUpdate=="function"&&(s.flags|=1024)):(typeof v.componentDidUpdate!="function"||A===t.memoizedProps&&oe===t.memoizedState||(s.flags|=4),typeof v.getSnapshotBeforeUpdate!="function"||A===t.memoizedProps&&oe===t.memoizedState||(s.flags|=1024),s.memoizedProps=u,s.memoizedState=ae),v.props=u,v.state=ae,v.context=F,u=le):(typeof v.componentDidUpdate!="function"||A===t.memoizedProps&&oe===t.memoizedState||(s.flags|=4),typeof v.getSnapshotBeforeUpdate!="function"||A===t.memoizedProps&&oe===t.memoizedState||(s.flags|=1024),u=!1)}return v=u,xc(t,s),u=(s.flags&128)!==0,v||u?(v=s.stateNode,i=u&&typeof i.getDerivedStateFromError!="function"?null:v.render(),s.flags|=1,t!==null&&u?(s.child=qo(s,t.child,null,p),s.child=qo(s,null,i,p)):Vt(t,s,i,p),s.memoizedState=v.state,t=s.child):t=ks(t,s,p),t}function _0(t,s,i,u){return ri(),s.flags|=256,Vt(t,s,i,u),s.child}var Kf={dehydrated:null,treeContext:null,retryLane:0,hydrationErrors:null};function Qf(t){return{baseLanes:t,cachePool:fx()}}function Jf(t,s,i){return t=t!==null?t.childLanes&~i:0,s&&(t|=Cn),t}function E0(t,s,i){var u=s.pendingProps,p=!1,v=(s.flags&128)!==0,A;if((A=v)||(A=t!==null&&t.memoizedState===null?!1:(zt.current&2)!==0),A&&(p=!0,s.flags&=-129),A=(s.flags&32)!==0,s.flags&=-33,t===null){if(at){if(p?tr(s):nr(),at){var z=jt,F;if(F=z){e:{for(F=z,z=Zn;F.nodeType!==8;){if(!z){z=null;break e}if(F=Hn(F.nextSibling),F===null){z=null;break e}}z=F}z!==null?(s.memoizedState={dehydrated:z,treeContext:Br!==null?{id:Ns,overflow:Ss}:null,retryLane:536870912,hydrationErrors:null},F=un(18,null,null,0),F.stateNode=z,F.return=s,s.child=F,Kt=s,jt=null,F=!0):F=!1}F||Vr(s)}if(z=s.memoizedState,z!==null&&(z=z.dehydrated,z!==null))return Im(z)?s.lanes=32:s.lanes=536870912,null;Cs(s)}return z=u.children,u=u.fallback,p?(nr(),p=s.mode,z=yc({mode:"hidden",children:z},p),u=$r(u,p,i,null),z.return=s,u.return=s,z.sibling=u,s.child=z,p=s.child,p.memoizedState=Qf(i),p.childLanes=Jf(t,A,i),s.memoizedState=Kf,u):(tr(s),em(s,z))}if(F=t.memoizedState,F!==null&&(z=F.dehydrated,z!==null)){if(v)s.flags&256?(tr(s),s.flags&=-257,s=tm(t,s,i)):s.memoizedState!==null?(nr(),s.child=t.child,s.flags|=128,s=null):(nr(),p=u.fallback,z=s.mode,u=yc({mode:"visible",children:u.children},z),p=$r(p,z,i,null),p.flags|=2,u.return=s,p.return=s,u.sibling=p,s.child=u,qo(s,t.child,null,i),u=s.child,u.memoizedState=Qf(i),u.childLanes=Jf(t,A,i),s.memoizedState=Kf,s=p);else if(tr(s),Im(z)){if(A=z.nextSibling&&z.nextSibling.dataset,A)var re=A.dgst;A=re,u=Error(a(419)),u.stack="",u.digest=A,oi({value:u,source:null,stack:null}),s=tm(t,s,i)}else if($t||ai(t,s,i,!1),A=(i&t.childLanes)!==0,$t||A){if(A=gt,A!==null&&(u=i&-i,u=(u&42)!==0?1:Ha(u),u=(u&(A.suspendedLanes|i))!==0?0:u,u!==0&&u!==F.retryLane))throw F.retryLane=u,Do(t,u),pn(A,t,u),x0;z.data==="$?"||vm(),s=tm(t,s,i)}else z.data==="$?"?(s.flags|=192,s.child=t.child,s=null):(t=F.treeContext,jt=Hn(z.nextSibling),Kt=s,at=!0,Ur=null,Zn=!1,t!==null&&(jn[_n++]=Ns,jn[_n++]=Ss,jn[_n++]=Br,Ns=t.id,Ss=t.overflow,Br=s),s=em(s,u.children),s.flags|=4096);return s}return p?(nr(),p=u.fallback,z=s.mode,F=t.child,re=F.sibling,u=ws(F,{mode:"hidden",children:u.children}),u.subtreeFlags=F.subtreeFlags&65011712,re!==null?p=ws(re,p):(p=$r(p,z,i,null),p.flags|=2),p.return=s,u.return=s,u.sibling=p,s.child=u,u=p,p=s.child,z=t.child.memoizedState,z===null?z=Qf(i):(F=z.cachePool,F!==null?(re=Ot._currentValue,F=F.parent!==re?{parent:re,pool:re}:F):F=fx(),z={baseLanes:z.baseLanes|i,cachePool:F}),p.memoizedState=z,p.childLanes=Jf(t,A,i),s.memoizedState=Kf,u):(tr(s),i=t.child,t=i.sibling,i=ws(i,{mode:"visible",children:u.children}),i.return=s,i.sibling=null,t!==null&&(A=s.deletions,A===null?(s.deletions=[t],s.flags|=16):A.push(t)),s.child=i,s.memoizedState=null,i)}function em(t,s){return s=yc({mode:"visible",children:s},t.mode),s.return=t,t.child=s}function yc(t,s){return t=un(22,t,null,s),t.lanes=0,t.stateNode={_visibility:1,_pendingMarkers:null,_retryCache:null,_transitions:null},t}function tm(t,s,i){return qo(s,t.child,null,i),t=em(s,s.pendingProps.children),t.flags|=2,s.memoizedState=null,t}function C0(t,s,i){t.lanes|=s;var u=t.alternate;u!==null&&(u.lanes|=s),vf(t.return,s,i)}function nm(t,s,i,u,p){var v=t.memoizedState;v===null?t.memoizedState={isBackwards:s,rendering:null,renderingStartTime:0,last:u,tail:i,tailMode:p}:(v.isBackwards=s,v.rendering=null,v.renderingStartTime=0,v.last=u,v.tail=i,v.tailMode=p)}function k0(t,s,i){var u=s.pendingProps,p=u.revealOrder,v=u.tail;if(Vt(t,s,u.children,i),u=zt.current,(u&2)!==0)u=u&1|2,s.flags|=128;else{if(t!==null&&(t.flags&128)!==0)e:for(t=s.child;t!==null;){if(t.tag===13)t.memoizedState!==null&&C0(t,i,s);else if(t.tag===19)C0(t,i,s);else if(t.child!==null){t.child.return=t,t=t.child;continue}if(t===s)break e;for(;t.sibling===null;){if(t.return===null||t.return===s)break e;t=t.return}t.sibling.return=t.return,t=t.sibling}u&=1}switch(V(zt,u),p){case"forwards":for(i=s.child,p=null;i!==null;)t=i.alternate,t!==null&&hc(t)===null&&(p=i),i=i.sibling;i=p,i===null?(p=s.child,s.child=null):(p=i.sibling,i.sibling=null),nm(s,!1,p,i,v);break;case"backwards":for(i=null,p=s.child,s.child=null;p!==null;){if(t=p.alternate,t!==null&&hc(t)===null){s.child=p;break}t=p.sibling,p.sibling=i,i=p,p=t}nm(s,!0,i,null,v);break;case"together":nm(s,!1,null,null,void 0);break;default:s.memoizedState=null}return s.child}function ks(t,s,i){if(t!==null&&(s.dependencies=t.dependencies),ir|=s.lanes,(i&s.childLanes)===0)if(t!==null){if(ai(t,s,i,!1),(i&s.childLanes)===0)return null}else return null;if(t!==null&&s.child!==t.child)throw Error(a(153));if(s.child!==null){for(t=s.child,i=ws(t,t.pendingProps),s.child=i,i.return=s;t.sibling!==null;)t=t.sibling,i=i.sibling=ws(t,t.pendingProps),i.return=s;i.sibling=null}return s.child}function sm(t,s){return(t.lanes&s)!==0?!0:(t=t.dependencies,!!(t!==null&&Ql(t)))}function y_(t,s,i){switch(s.tag){case 3:ie(s,s.stateNode.containerInfo),Ws(s,Ot,t.memoizedState.cache),ri();break;case 27:case 5:be(s);break;case 4:ie(s,s.stateNode.containerInfo);break;case 10:Ws(s,s.type,s.memoizedProps.value);break;case 13:var u=s.memoizedState;if(u!==null)return u.dehydrated!==null?(tr(s),s.flags|=128,null):(i&s.child.childLanes)!==0?E0(t,s,i):(tr(s),t=ks(t,s,i),t!==null?t.sibling:null);tr(s);break;case 19:var p=(t.flags&128)!==0;if(u=(i&s.childLanes)!==0,u||(ai(t,s,i,!1),u=(i&s.childLanes)!==0),p){if(u)return k0(t,s,i);s.flags|=128}if(p=s.memoizedState,p!==null&&(p.rendering=null,p.tail=null,p.lastEffect=null),V(zt,zt.current),u)break;return null;case 22:case 23:return s.lanes=0,w0(t,s,i);case 24:Ws(s,Ot,t.memoizedState.cache)}return ks(t,s,i)}function A0(t,s,i){if(t!==null)if(t.memoizedProps!==s.pendingProps)$t=!0;else{if(!sm(t,i)&&(s.flags&128)===0)return $t=!1,y_(t,s,i);$t=(t.flags&131072)!==0}else $t=!1,at&&(s.flags&1048576)!==0&&ox(s,Kl,s.index);switch(s.lanes=0,s.tag){case 16:e:{t=s.pendingProps;var u=s.elementType,p=u._init;if(u=p(u._payload),s.type=u,typeof u=="function")ff(u)?(t=Xr(u,t),s.tag=1,s=j0(null,s,u,t,i)):(s.tag=0,s=Wf(null,s,u,t,i));else{if(u!=null){if(p=u.$$typeof,p===T){s.tag=11,s=y0(null,s,u,t,i);break e}else if(p===O){s.tag=14,s=v0(null,s,u,t,i);break e}}throw s=P(u)||u,Error(a(306,s,""))}}return s;case 0:return Wf(t,s,s.type,s.pendingProps,i);case 1:return u=s.type,p=Xr(u,s.pendingProps),j0(t,s,u,p,i);case 3:e:{if(ie(s,s.stateNode.containerInfo),t===null)throw Error(a(387));u=s.pendingProps;var v=s.memoizedState;p=v.element,Ef(t,s),mi(s,u,null,i);var A=s.memoizedState;if(u=A.cache,Ws(s,Ot,u),u!==v.cache&&bf(s,[Ot],i,!0),fi(),u=A.element,v.isDehydrated)if(v={element:u,isDehydrated:!1,cache:A.cache},s.updateQueue.baseState=v,s.memoizedState=v,s.flags&256){s=_0(t,s,u,i);break e}else if(u!==p){p=Nn(Error(a(424)),s),oi(p),s=_0(t,s,u,i);break e}else{switch(t=s.stateNode.containerInfo,t.nodeType){case 9:t=t.body;break;default:t=t.nodeName==="HTML"?t.ownerDocument.body:t}for(jt=Hn(t.firstChild),Kt=s,at=!0,Ur=null,Zn=!0,i=i0(s,null,u,i),s.child=i;i;)i.flags=i.flags&-3|4096,i=i.sibling}else{if(ri(),u===p){s=ks(t,s,i);break e}Vt(t,s,u,i)}s=s.child}return s;case 26:return xc(t,s),t===null?(i=Dy(s.type,null,s.pendingProps,null))?s.memoizedState=i:at||(i=s.type,t=s.pendingProps,u=Rc(ue.current).createElement(i),u[Ht]=s,u[Xt]=t,Ft(u,i,t),Mt(u),s.stateNode=u):s.memoizedState=Dy(s.type,t.memoizedProps,s.pendingProps,t.memoizedState),null;case 27:return be(s),t===null&&at&&(u=s.stateNode=My(s.type,s.pendingProps,ue.current),Kt=s,Zn=!0,p=jt,dr(s.type)?(Lm=p,jt=Hn(u.firstChild)):jt=p),Vt(t,s,s.pendingProps.children,i),xc(t,s),t===null&&(s.flags|=4194304),s.child;case 5:return t===null&&at&&((p=u=jt)&&(u=Y_(u,s.type,s.pendingProps,Zn),u!==null?(s.stateNode=u,Kt=s,jt=Hn(u.firstChild),Zn=!1,p=!0):p=!1),p||Vr(s)),be(s),p=s.type,v=s.pendingProps,A=t!==null?t.memoizedProps:null,u=v.children,Dm(p,v)?u=null:A!==null&&Dm(p,A)&&(s.flags|=32),s.memoizedState!==null&&(p=Rf(t,s,u_,null,null,i),zi._currentValue=p),xc(t,s),Vt(t,s,u,i),s.child;case 6:return t===null&&at&&((t=i=jt)&&(i=G_(i,s.pendingProps,Zn),i!==null?(s.stateNode=i,Kt=s,jt=null,t=!0):t=!1),t||Vr(s)),null;case 13:return E0(t,s,i);case 4:return ie(s,s.stateNode.containerInfo),u=s.pendingProps,t===null?s.child=qo(s,null,u,i):Vt(t,s,u,i),s.child;case 11:return y0(t,s,s.type,s.pendingProps,i);case 7:return Vt(t,s,s.pendingProps,i),s.child;case 8:return Vt(t,s,s.pendingProps.children,i),s.child;case 12:return Vt(t,s,s.pendingProps.children,i),s.child;case 10:return u=s.pendingProps,Ws(s,s.type,u.value),Vt(t,s,u.children,i),s.child;case 9:return p=s.type._context,u=s.pendingProps.children,Fr(s),p=Zt(p),u=u(p),s.flags|=1,Vt(t,s,u,i),s.child;case 14:return v0(t,s,s.type,s.pendingProps,i);case 15:return b0(t,s,s.type,s.pendingProps,i);case 19:return k0(t,s,i);case 31:return u=s.pendingProps,i=s.mode,u={mode:u.mode,children:u.children},t===null?(i=yc(u,i),i.ref=s.ref,s.child=i,i.return=s,s=i):(i=ws(t.child,u),i.ref=s.ref,s.child=i,i.return=s,s=i),s;case 22:return w0(t,s,i);case 24:return Fr(s),u=Zt(Ot),t===null?(p=Sf(),p===null&&(p=gt,v=wf(),p.pooledCache=v,v.refCount++,v!==null&&(p.pooledCacheLanes|=i),p=v),s.memoizedState={parent:u,cache:p},_f(s),Ws(s,Ot,p)):((t.lanes&i)!==0&&(Ef(t,s),mi(s,null,null,i),fi()),p=t.memoizedState,v=s.memoizedState,p.parent!==u?(p={parent:u,cache:u},s.memoizedState=p,s.lanes===0&&(s.memoizedState=s.updateQueue.baseState=p),Ws(s,Ot,u)):(u=v.cache,Ws(s,Ot,u),u!==p.cache&&bf(s,[Ot],i,!0))),Vt(t,s,s.pendingProps.children,i),s.child;case 29:throw s.pendingProps}throw Error(a(156,s.tag))}function As(t){t.flags|=4}function M0(t,s){if(s.type!=="stylesheet"||(s.state.loading&4)!==0)t.flags&=-16777217;else if(t.flags|=16777216,!Hy(s)){if(s=En.current,s!==null&&((nt&4194048)===nt?Wn!==null:(nt&62914560)!==nt&&(nt&536870912)===0||s!==Wn))throw ui=jf,mx;t.flags|=8192}}function vc(t,s){s!==null&&(t.flags|=4),t.flags&16384&&(s=t.tag!==22?ot():536870912,t.lanes|=s,Xo|=s)}function bi(t,s){if(!at)switch(t.tailMode){case"hidden":s=t.tail;for(var i=null;s!==null;)s.alternate!==null&&(i=s),s=s.sibling;i===null?t.tail=null:i.sibling=null;break;case"collapsed":i=t.tail;for(var u=null;i!==null;)i.alternate!==null&&(u=i),i=i.sibling;u===null?s||t.tail===null?t.tail=null:t.tail.sibling=null:u.sibling=null}}function St(t){var s=t.alternate!==null&&t.alternate.child===t.child,i=0,u=0;if(s)for(var p=t.child;p!==null;)i|=p.lanes|p.childLanes,u|=p.subtreeFlags&65011712,u|=p.flags&65011712,p.return=t,p=p.sibling;else for(p=t.child;p!==null;)i|=p.lanes|p.childLanes,u|=p.subtreeFlags,u|=p.flags,p.return=t,p=p.sibling;return t.subtreeFlags|=u,t.childLanes=i,s}function v_(t,s,i){var u=s.pendingProps;switch(gf(s),s.tag){case 31:case 16:case 15:case 0:case 11:case 7:case 8:case 12:case 9:case 14:return St(s),null;case 1:return St(s),null;case 3:return i=s.stateNode,u=null,t!==null&&(u=t.memoizedState.cache),s.memoizedState.cache!==u&&(s.flags|=2048),_s(Ot),ge(),i.pendingContext&&(i.context=i.pendingContext,i.pendingContext=null),(t===null||t.child===null)&&(si(s)?As(s):t===null||t.memoizedState.isDehydrated&&(s.flags&256)===0||(s.flags|=1024,lx())),St(s),null;case 26:return i=s.memoizedState,t===null?(As(s),i!==null?(St(s),M0(s,i)):(St(s),s.flags&=-16777217)):i?i!==t.memoizedState?(As(s),St(s),M0(s,i)):(St(s),s.flags&=-16777217):(t.memoizedProps!==u&&As(s),St(s),s.flags&=-16777217),null;case 27:we(s),i=ue.current;var p=s.type;if(t!==null&&s.stateNode!=null)t.memoizedProps!==u&&As(s);else{if(!u){if(s.stateNode===null)throw Error(a(166));return St(s),null}t=W.current,si(s)?ax(s):(t=My(p,u,i),s.stateNode=t,As(s))}return St(s),null;case 5:if(we(s),i=s.type,t!==null&&s.stateNode!=null)t.memoizedProps!==u&&As(s);else{if(!u){if(s.stateNode===null)throw Error(a(166));return St(s),null}if(t=W.current,si(s))ax(s);else{switch(p=Rc(ue.current),t){case 1:t=p.createElementNS("http://www.w3.org/2000/svg",i);break;case 2:t=p.createElementNS("http://www.w3.org/1998/Math/MathML",i);break;default:switch(i){case"svg":t=p.createElementNS("http://www.w3.org/2000/svg",i);break;case"math":t=p.createElementNS("http://www.w3.org/1998/Math/MathML",i);break;case"script":t=p.createElement("div"),t.innerHTML="