diff --git a/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/__init__.py b/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/__init__.py index fdc17a34..e19cf8e9 100644 --- a/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/__init__.py +++ b/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/__init__.py @@ -12,6 +12,12 @@ ) from .execute_tool_scope import ExecuteToolScope from .execution_type import ExecutionType +from .exporters.enriched_span import EnrichedReadableSpan +from .exporters.enriching_span_processor import ( + get_span_enricher, + register_span_enricher, + unregister_span_enricher, +) from .inference_call_details import InferenceCallDetails from .inference_operation_type import InferenceOperationType from .inference_scope import InferenceScope @@ -32,6 +38,11 @@ "is_configured", "get_tracer", "get_tracer_provider", + # Span enrichment + "register_span_enricher", + "unregister_span_enricher", + "get_span_enricher", + "EnrichedReadableSpan", # Span processor "SpanProcessor", # Base scope class diff --git a/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/config.py b/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/config.py index 26eab5d2..6c3befdf 100644 --- a/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/config.py +++ b/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/config.py @@ -9,10 +9,13 @@ from opentelemetry import trace from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, Resource from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter +from opentelemetry.sdk.trace.export import ConsoleSpanExporter from .exporters.agent365_exporter import _Agent365Exporter from .exporters.agent365_exporter_options import Agent365ExporterOptions +from .exporters.enriching_span_processor import ( + _EnrichingBatchSpanProcessor, +) from .exporters.utils import is_agent365_exporter_enabled from .trace_processor.span_processor import SpanProcessor @@ -166,8 +169,9 @@ def _configure_internal( # Add span processors - # Create BatchSpanProcessor with optimized settings - batch_processor = BatchSpanProcessor(exporter, **batch_processor_kwargs) + # Create _EnrichingBatchSpanProcessor with optimized settings + # This allows extensions to enrich spans before export + batch_processor = _EnrichingBatchSpanProcessor(exporter, **batch_processor_kwargs) agent_processor = SpanProcessor() tracer_provider.add_span_processor(batch_processor) diff --git a/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/exporters/enriched_span.py b/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/exporters/enriched_span.py new file mode 100644 index 00000000..b57bd4e7 --- /dev/null +++ b/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/exporters/enriched_span.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Enriched ReadableSpan wrapper for adding attributes to immutable spans.""" + +import json +from typing import Any + +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.util import types + + +class EnrichedReadableSpan(ReadableSpan): + """ + Wrapper to add attributes to an immutable ReadableSpan. + + Since ReadableSpan is immutable after a span ends, this wrapper allows + extensions to add additional attributes before export without modifying + the original span. + """ + + def __init__(self, span: ReadableSpan, extra_attributes: dict): + """ + Initialize the enriched span wrapper. + + Args: + span: The original ReadableSpan to wrap. + extra_attributes: Additional attributes to merge with the original. + """ + self._span = span + self._extra_attributes = extra_attributes + + @property + def attributes(self) -> types.Attributes: + """Return merged attributes from original span and extra attributes.""" + original = dict(self._span.attributes or {}) + original.update(self._extra_attributes) + return original + + @property + def name(self): + """Return the span name.""" + return self._span.name + + @property + def context(self): + """Return the span context.""" + return self._span.context + + @property + def parent(self): + """Return the parent span context.""" + return self._span.parent + + @property + def start_time(self): + """Return the span start time.""" + return self._span.start_time + + @property + def end_time(self): + """Return the span end time.""" + return self._span.end_time + + @property + def status(self): + """Return the span status.""" + return self._span.status + + @property + def kind(self): + """Return the span kind.""" + return self._span.kind + + @property + def events(self): + """Return the span events.""" + return self._span.events + + @property + def links(self): + """Return the span links.""" + return self._span.links + + @property + def resource(self): + """Return the span resource.""" + return self._span.resource + + @property + def instrumentation_scope(self): + """Return the instrumentation scope.""" + return self._span.instrumentation_scope + + def to_json(self, indent: int | None = 4) -> str: + """ + Convert span to JSON string with enriched attributes. + + Args: + indent: JSON indentation level. + + Returns: + JSON string representation of the span. + """ + # Build the JSON dict manually to include enriched attributes + return json.dumps( + { + "name": self.name, + "context": { + "trace_id": f"0x{self.context.trace_id:032x}", + "span_id": f"0x{self.context.span_id:016x}", + "trace_state": str(self.context.trace_state), + } + if self.context + else None, + "kind": str(self.kind), + "parent_id": f"0x{self.parent.span_id:016x}" if self.parent else None, + "start_time": self._format_time(self.start_time), + "end_time": self._format_time(self.end_time), + "status": { + "status_code": str(self.status.status_code), + "description": self.status.description, + } + if self.status + else None, + "attributes": dict(self.attributes) if self.attributes else None, + "events": [self._format_event(e) for e in self.events] if self.events else None, + "links": [self._format_link(lnk) for lnk in self.links] if self.links else None, + "resource": dict(self.resource.attributes) if self.resource else None, + }, + indent=indent, + ) + + def _format_time(self, time_ns: int | None) -> str | None: + """Format nanosecond timestamp to ISO string.""" + if time_ns is None: + return None + from datetime import datetime, timezone + + return datetime.fromtimestamp(time_ns / 1e9, tz=timezone.utc).isoformat() + + def _format_event(self, event: Any) -> dict: + """Format a span event.""" + return { + "name": event.name, + "timestamp": self._format_time(event.timestamp), + "attributes": dict(event.attributes) if event.attributes else None, + } + + def _format_link(self, link: Any) -> dict: + """Format a span link.""" + return { + "context": { + "trace_id": f"0x{link.context.trace_id:032x}", + "span_id": f"0x{link.context.span_id:016x}", + } + if link.context + else None, + "attributes": dict(link.attributes) if link.attributes else None, + } diff --git a/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/exporters/enriching_span_processor.py b/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/exporters/enriching_span_processor.py new file mode 100644 index 00000000..03c54775 --- /dev/null +++ b/libraries/microsoft-agents-a365-observability-core/microsoft_agents_a365/observability/core/exporters/enriching_span_processor.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Span enrichment support for the Agent365 exporter pipeline.""" + +import logging +import threading +from collections.abc import Callable + +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import BatchSpanProcessor + +logger = logging.getLogger(__name__) + +# Single span enricher - only one platform instrumentor should be active at a time +_span_enricher: Callable[[ReadableSpan], ReadableSpan] | None = None +_enricher_lock = threading.Lock() + + +def register_span_enricher(enricher: Callable[[ReadableSpan], ReadableSpan]) -> None: + """Register the span enricher for the active platform instrumentor. + + Only one enricher can be registered at a time since auto-instrumentation + is platform-specific (Semantic Kernel, LangChain, or OpenAI Agents). + + Args: + enricher: Function that takes a ReadableSpan and returns an enriched span. + + Raises: + RuntimeError: If an enricher is already registered. + """ + global _span_enricher + with _enricher_lock: + if _span_enricher is not None: + raise RuntimeError( + "A span enricher is already registered. " + "Only one platform instrumentor can be active at a time." + ) + _span_enricher = enricher + logger.debug("Span enricher registered: %s", enricher.__name__) + + +def unregister_span_enricher() -> None: + """Unregister the current span enricher. + + Called during uninstrumentation to clean up. + """ + global _span_enricher + with _enricher_lock: + if _span_enricher is not None: + logger.debug("Span enricher unregistered: %s", _span_enricher.__name__) + _span_enricher = None + + +def get_span_enricher() -> Callable[[ReadableSpan], ReadableSpan] | None: + """Get the currently registered span enricher. + + Returns: + The registered enricher function, or None if no enricher is registered. + """ + with _enricher_lock: + return _span_enricher + + +class _EnrichingBatchSpanProcessor(BatchSpanProcessor): + """BatchSpanProcessor that applies the registered enricher before batching.""" + + def on_end(self, span: ReadableSpan) -> None: + """Apply the span enricher and pass to parent for batching. + + Args: + span: The span that has ended. + """ + enriched_span = span + + enricher = get_span_enricher() + if enricher is not None: + try: + enriched_span = enricher(span) + except Exception: + logger.exception( + "Span enricher %s raised an exception, using original span", + enricher.__name__, + ) + + super().on_end(enriched_span) diff --git a/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/span_enricher.py b/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/span_enricher.py new file mode 100644 index 00000000..83f14213 --- /dev/null +++ b/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/span_enricher.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Span enricher for Semantic Kernel.""" + +from microsoft_agents_a365.observability.core.constants import ( + EXECUTE_TOOL_OPERATION_NAME, + GEN_AI_INPUT_MESSAGES_KEY, + GEN_AI_OUTPUT_MESSAGES_KEY, + GEN_AI_TOOL_ARGS_KEY, + GEN_AI_TOOL_CALL_RESULT_KEY, + INVOKE_AGENT_OPERATION_NAME, +) +from microsoft_agents_a365.observability.core.exporters.enriched_span import EnrichedReadableSpan +from opentelemetry.sdk.trace import ReadableSpan + +from .utils import extract_content_as_string_list + +# Semantic Kernel specific attribute keys +SK_TOOL_CALL_ARGUMENTS_KEY = "gen_ai.tool.call.arguments" +SK_TOOL_CALL_RESULT_KEY = "gen_ai.tool.call.result" + + +def enrich_semantic_kernel_span(span: ReadableSpan) -> ReadableSpan: + """ + Enricher function for Semantic Kernel spans. + + Transforms SK-specific attributes to standard gen_ai attributes + before the span is exported. Enrichment is applied based on span type: + - invoke_agent spans: Extract only content from input/output messages + - execute_tool spans: Map tool arguments and results to standard keys + + Args: + span: The ReadableSpan to enrich. + + Returns: + The enriched span (wrapped if attributes were added), or the + original span if no enrichment was needed. + """ + extra_attributes = {} + attributes = span.attributes or {} + + # Only extract content for invoke_agent spans + if span.name.startswith(INVOKE_AGENT_OPERATION_NAME): + # Transform SK-specific agent invocation attributes to standard gen_ai attributes + # Extract only the content from the full message objects + # Support both gen_ai.agent.invocation_input and gen_ai.input_messages as sources + input_messages = attributes.get("gen_ai.agent.invocation_input") or attributes.get( + GEN_AI_INPUT_MESSAGES_KEY + ) + if input_messages: + extra_attributes[GEN_AI_INPUT_MESSAGES_KEY] = extract_content_as_string_list( + input_messages + ) + + output_messages = attributes.get("gen_ai.agent.invocation_output") or attributes.get( + GEN_AI_OUTPUT_MESSAGES_KEY + ) + if output_messages: + extra_attributes[GEN_AI_OUTPUT_MESSAGES_KEY] = extract_content_as_string_list( + output_messages + ) + + # Map tool attributes for execute_tool spans + elif span.name.startswith(EXECUTE_TOOL_OPERATION_NAME): + if SK_TOOL_CALL_ARGUMENTS_KEY in attributes: + extra_attributes[GEN_AI_TOOL_ARGS_KEY] = attributes[SK_TOOL_CALL_ARGUMENTS_KEY] + + if SK_TOOL_CALL_RESULT_KEY in attributes: + extra_attributes[GEN_AI_TOOL_CALL_RESULT_KEY] = attributes[SK_TOOL_CALL_RESULT_KEY] + + if extra_attributes: + return EnrichedReadableSpan(span, extra_attributes) + + return span diff --git a/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/span_processor.py b/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/span_processor.py index c78f748a..6763f71d 100644 --- a/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/span_processor.py +++ b/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/span_processor.py @@ -1,11 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# Custom Span Processor - -from microsoft_agents_a365.observability.core.constants import GEN_AI_OPERATION_NAME_KEY +from microsoft_agents_a365.observability.core.constants import ( + GEN_AI_EXECUTION_TYPE_KEY, + GEN_AI_OPERATION_NAME_KEY, + INVOKE_AGENT_OPERATION_NAME, +) +from microsoft_agents_a365.observability.core.execution_type import ExecutionType from microsoft_agents_a365.observability.core.inference_operation_type import InferenceOperationType from microsoft_agents_a365.observability.core.utils import extract_model_name +from opentelemetry import context as context_api +from opentelemetry.sdk.trace import ReadableSpan, Span from opentelemetry.sdk.trace.export import SpanProcessor @@ -15,13 +20,42 @@ class SemanticKernelSpanProcessor(SpanProcessor): """ def __init__(self, service_name: str | None = None): + """ + Initialize the Semantic Kernel span processor. + + Args: + service_name: Optional service name for span enrichment. + """ self.service_name = service_name - def on_start(self, span, parent_context): + def on_start(self, span: Span, parent_context: context_api.Context | None) -> None: + """ + Modify span while it's still writable. + + Args: + span: The span that is starting (writable). + parent_context: The parent context of the span. + """ if span.name.startswith("chat."): span.set_attribute(GEN_AI_OPERATION_NAME_KEY, InferenceOperationType.CHAT.value.lower()) model_name = extract_model_name(span.name) span.update_name(f"{InferenceOperationType.CHAT.value.lower()} {model_name}") - def on_end(self, span): + if span.name.startswith(INVOKE_AGENT_OPERATION_NAME): + span.set_attribute( + GEN_AI_EXECUTION_TYPE_KEY, ExecutionType.HUMAN_TO_AGENT.value.lower() + ) + + def on_end(self, span: ReadableSpan) -> None: + """ + Called when a span ends. + """ pass + + def shutdown(self) -> None: + """Shutdown the processor.""" + pass + + def force_flush(self, timeout_millis: int = 30000) -> bool: + """Force flush any pending spans.""" + return True diff --git a/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/trace_instrumentor.py b/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/trace_instrumentor.py index bd44bb42..643119ce 100644 --- a/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/trace_instrumentor.py +++ b/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/trace_instrumentor.py @@ -6,30 +6,35 @@ from collections.abc import Collection from typing import Any -from microsoft_agents_a365.observability.core.config import get_tracer_provider, is_configured +from microsoft_agents_a365.observability.core.config import ( + get_tracer_provider, + is_configured, +) +from microsoft_agents_a365.observability.core.exporters.enriching_span_processor import ( + register_span_enricher, + unregister_span_enricher, +) from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from microsoft_agents_a365.observability.extensions.semantickernel.span_enricher import ( + enrich_semantic_kernel_span, +) from microsoft_agents_a365.observability.extensions.semantickernel.span_processor import ( SemanticKernelSpanProcessor, ) -# ----------------------------- -# 3) The Instrumentor class -# ----------------------------- _instruments = ("semantic-kernel >= 1.0.0",) class SemanticKernelInstrumentor(BaseInstrumentor): """ - Instruments Semantic Kernel: - • Installs your custom OTel SpanProcessor - • (Optionally) attaches an SK function-invocation filter to enrich spans + Instruments Semantic Kernel with Agent365 observability. """ def __init__(self): if not is_configured(): raise RuntimeError( - "Microsoft Agent 365 (or your telemetry config) is not initialized. Configure it before instrumenting." + "Microsoft Agent 365 is not initialized. Call configure() before instrumenting." ) super().__init__() @@ -38,13 +43,32 @@ def instrumentation_dependencies(self) -> Collection[str]: def _instrument(self, **kwargs: Any) -> None: """ - kwargs (all optional): - """ + Instrument Semantic Kernel. - # Ensure we have an SDK TracerProvider + Args: + **kwargs: Optional configuration parameters. + """ provider = get_tracer_provider() + + # Add processor for on_start modifications (rename spans, add attributes) self._processor = SemanticKernelSpanProcessor() provider.add_span_processor(self._processor) + # Register enricher for on_end modifications + # This enricher runs before the span is exported, allowing us to + # transform SK-specific attributes to standard gen_ai attributes + register_span_enricher(enrich_semantic_kernel_span) + def _uninstrument(self, **kwargs: Any) -> None: - pass + """ + Remove Semantic Kernel instrumentation. + + Args: + **kwargs: Optional configuration parameters. + """ + # Unregister the enricher + unregister_span_enricher() + + # Shutdown the processor + if hasattr(self, "_processor"): + self._processor.shutdown() diff --git a/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/utils.py b/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/utils.py new file mode 100644 index 00000000..dd1515b7 --- /dev/null +++ b/libraries/microsoft-agents-a365-observability-extensions-semantickernel/microsoft_agents_a365/observability/extensions/semantickernel/utils.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Utility functions for Semantic Kernel observability extensions.""" + +from __future__ import annotations + +import json + + +def extract_content_as_string_list(messages_json: str) -> str: + """Extract content values from messages JSON and return as JSON string list. + + Transforms from: [{"role": "user", "content": "Hello"}] + To: ["Hello"] + + Args: + messages_json: JSON string like '[{"role": "user", "content": "Hello"}]' + + Returns: + JSON string containing only the content values as an array, + or the original string if parsing fails. + """ + try: + messages = json.loads(messages_json) + if isinstance(messages, list): + contents = [] + for msg in messages: + if isinstance(msg, dict) and "content" in msg: + contents.append(msg["content"]) + elif isinstance(msg, str): + contents.append(msg) + return json.dumps(contents) + return messages_json + except (json.JSONDecodeError, TypeError): + # If parsing fails, return as-is + return messages_json diff --git a/libraries/microsoft-agents-a365-tooling-extensions-agentframework/microsoft_agents_a365/tooling/extensions/agentframework/services/mcp_tool_registration_service.py b/libraries/microsoft-agents-a365-tooling-extensions-agentframework/microsoft_agents_a365/tooling/extensions/agentframework/services/mcp_tool_registration_service.py index e895db06..6c7cc540 100644 --- a/libraries/microsoft-agents-a365-tooling-extensions-agentframework/microsoft_agents_a365/tooling/extensions/agentframework/services/mcp_tool_registration_service.py +++ b/libraries/microsoft-agents-a365-tooling-extensions-agentframework/microsoft_agents_a365/tooling/extensions/agentframework/services/mcp_tool_registration_service.py @@ -9,6 +9,7 @@ from agent_framework import ChatAgent, ChatMessage, ChatMessageStoreProtocol, MCPStreamableHTTPTool from agent_framework.azure import AzureOpenAIChatClient from agent_framework.openai import OpenAIChatClient +import httpx from microsoft_agents.hosting.core import Authorization, TurnContext @@ -24,6 +25,10 @@ ) +# Default timeout for MCP server HTTP requests (in seconds) +MCP_HTTP_CLIENT_TIMEOUT_SECONDS = 90.0 + + class McpToolRegistrationService: """ Provides MCP tool registration services for Agent Framework agents. @@ -46,6 +51,7 @@ def __init__(self, logger: Optional[logging.Logger] = None): logger=self._logger ) self._connected_servers = [] + self._http_clients: List[httpx.AsyncClient] = [] async def add_tool_servers_to_agent( self, @@ -114,11 +120,17 @@ async def add_tool_servers_to_agent( self._orchestrator_name ) - # Create and configure MCPStreamableHTTPTool + # Create httpx client with auth headers configured + http_client = httpx.AsyncClient( + headers=headers, timeout=MCP_HTTP_CLIENT_TIMEOUT_SECONDS + ) + self._http_clients.append(http_client) + + # Create and configure MCPStreamableHTTPTool with http_client mcp_tools = MCPStreamableHTTPTool( name=server_name, url=config.url, - headers=headers, + http_client=http_client, description=f"MCP tools from {server_name}", ) @@ -339,12 +351,21 @@ async def send_chat_history_from_store( async def cleanup(self): """Clean up any resources used by the service.""" try: + # Close MCP server connections for plugin in self._connected_servers: try: if hasattr(plugin, "close"): await plugin.close() except Exception as cleanup_ex: - self._logger.debug(f"Error during cleanup: {cleanup_ex}") + self._logger.debug(f"Error during plugin cleanup: {cleanup_ex}") self._connected_servers.clear() + + # Close httpx clients to prevent connection/file descriptor leaks + for http_client in self._http_clients: + try: + await http_client.aclose() + except Exception as client_ex: + self._logger.debug(f"Error closing http client: {client_ex}") + self._http_clients.clear() except Exception as ex: self._logger.debug(f"Error during service cleanup: {ex}") diff --git a/libraries/microsoft-agents-a365-tooling-extensions-agentframework/pyproject.toml b/libraries/microsoft-agents-a365-tooling-extensions-agentframework/pyproject.toml index e155ce6d..9c1d8452 100644 --- a/libraries/microsoft-agents-a365-tooling-extensions-agentframework/pyproject.toml +++ b/libraries/microsoft-agents-a365-tooling-extensions-agentframework/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "agent-framework-azure-ai >= 1.0.0b251114", "azure-identity >= 1.12.0", "typing-extensions >= 4.0.0", + "httpx >= 0.27.0", ] [project.urls] diff --git a/tests/observability/core/exporters/test_enriched_span.py b/tests/observability/core/exporters/test_enriched_span.py new file mode 100644 index 00000000..1c4e1f48 --- /dev/null +++ b/tests/observability/core/exporters/test_enriched_span.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for EnrichedReadableSpan.""" + +import unittest +from unittest.mock import Mock + +from microsoft_agents_a365.observability.core.exporters.enriched_span import EnrichedReadableSpan + + +class TestEnrichedReadableSpan(unittest.TestCase): + """Test suite for EnrichedReadableSpan.""" + + def test_attributes_merges_original_and_extra(self): + """Test that attributes property merges original span attributes with extra attributes.""" + # Create mock span with original attributes + mock_span = Mock() + mock_span.attributes = {"original_key": "original_value", "shared_key": "original"} + + # Create enriched span with extra attributes + extra_attributes = {"extra_key": "extra_value", "shared_key": "overwritten"} + enriched_span = EnrichedReadableSpan(mock_span, extra_attributes) + + # Verify merged attributes + attributes = enriched_span.attributes + self.assertEqual(attributes["original_key"], "original_value") + self.assertEqual(attributes["extra_key"], "extra_value") + self.assertEqual(attributes["shared_key"], "overwritten") # Extra should overwrite original + + def test_delegates_all_properties_to_wrapped_span(self): + """Test that all span properties are delegated to the wrapped span.""" + # Create mock span with all properties + mock_span = Mock() + mock_span.name = "test-span" + mock_span.context = Mock(trace_id=123, span_id=456) + mock_span.parent = Mock(span_id=789) + mock_span.start_time = 1000000000 + mock_span.end_time = 2000000000 + mock_span.status = Mock(status_code="OK", description=None) + mock_span.kind = "INTERNAL" + mock_span.events = [] + mock_span.links = [] + mock_span.resource = Mock(attributes={"service.name": "test"}) + mock_span.instrumentation_scope = Mock(name="test-scope") + mock_span.attributes = {} + + enriched_span = EnrichedReadableSpan(mock_span, {}) + + # Verify all properties delegate correctly + self.assertEqual(enriched_span.name, "test-span") + self.assertEqual(enriched_span.context, mock_span.context) + self.assertEqual(enriched_span.parent, mock_span.parent) + self.assertEqual(enriched_span.start_time, 1000000000) + self.assertEqual(enriched_span.end_time, 2000000000) + self.assertEqual(enriched_span.status, mock_span.status) + self.assertEqual(enriched_span.kind, "INTERNAL") + self.assertEqual(enriched_span.events, []) + self.assertEqual(enriched_span.links, []) + self.assertEqual(enriched_span.resource, mock_span.resource) + self.assertEqual(enriched_span.instrumentation_scope, mock_span.instrumentation_scope) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/observability/core/exporters/test_enriching_span_processor.py b/tests/observability/core/exporters/test_enriching_span_processor.py new file mode 100644 index 00000000..759a07a4 --- /dev/null +++ b/tests/observability/core/exporters/test_enriching_span_processor.py @@ -0,0 +1,154 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for enriching_span_processor module.""" + +import unittest +from unittest.mock import Mock + +from microsoft_agents_a365.observability.core.exporters.enriching_span_processor import ( + _EnrichingBatchSpanProcessor, + get_span_enricher, + register_span_enricher, + unregister_span_enricher, +) + + +class TestSpanEnricherRegistry(unittest.TestCase): + """Test suite for span enricher registration functions.""" + + def setUp(self): + """Ensure clean state before each test.""" + unregister_span_enricher() + + def tearDown(self): + """Clean up after each test.""" + unregister_span_enricher() + + def test_register_and_unregister_enricher(self): + """Test that enricher can be registered and unregistered.""" + + def my_enricher(span): + return span + + # Initially no enricher + self.assertIsNone(get_span_enricher()) + + # Register + register_span_enricher(my_enricher) + self.assertEqual(get_span_enricher(), my_enricher) + + # Unregister + unregister_span_enricher() + self.assertIsNone(get_span_enricher()) + + def test_register_second_enricher_raises_error(self): + """Test that registering a second enricher raises RuntimeError.""" + + def enricher_one(span): + return span + + def enricher_two(span): + return span + + register_span_enricher(enricher_one) + + with self.assertRaises(RuntimeError) as context: + register_span_enricher(enricher_two) + + self.assertIn("already registered", str(context.exception)) + + def test_unregister_when_none_registered_is_safe(self): + """Test that unregistering when no enricher is registered doesn't raise.""" + # Should not raise + unregister_span_enricher() + self.assertIsNone(get_span_enricher()) + + +class TestEnrichingBatchSpanProcessor(unittest.TestCase): + """Test suite for _EnrichingBatchSpanProcessor.""" + + def setUp(self): + """Ensure clean state before each test.""" + unregister_span_enricher() + + def tearDown(self): + """Clean up after each test.""" + unregister_span_enricher() + + def test_on_end_applies_enricher_to_span(self): + """Test that on_end applies the registered enricher to the span.""" + # Create processor with a mock exporter + mock_exporter = Mock() + processor = _EnrichingBatchSpanProcessor(mock_exporter) + + # Register an enricher that tracks what it receives and returns + received_spans = [] + + def enricher(span): + received_spans.append(span) + # Return a mock enriched span + enriched = Mock(name="enriched_span") + enriched.context = span.context + return enriched + + register_span_enricher(enricher) + + # Create a mock span + original_span = Mock(name="original_span") + original_span.context = Mock() + original_span.context.trace_id = 123 + original_span.context.span_id = 456 + + # Call on_end + processor.on_end(original_span) + + # Verify enricher was called with the original span + self.assertEqual(len(received_spans), 1) + self.assertEqual(received_spans[0], original_span) + + # Cleanup + processor.shutdown() + + def test_on_end_continues_if_enricher_raises_exception(self): + """Test that on_end continues processing even if enricher raises an exception.""" + mock_exporter = Mock() + processor = _EnrichingBatchSpanProcessor(mock_exporter) + + def failing_enricher(span): + raise ValueError("Enricher failed!") + + register_span_enricher(failing_enricher) + + # Create a mock span + original_span = Mock(name="original_span") + original_span.context = Mock() + original_span.context.trace_id = 123 + original_span.context.span_id = 456 + + # Should not raise despite failing enricher + processor.on_end(original_span) + + # Cleanup + processor.shutdown() + + def test_on_end_works_without_enricher(self): + """Test that on_end works when no enricher is registered.""" + mock_exporter = Mock() + processor = _EnrichingBatchSpanProcessor(mock_exporter) + + # Create a mock span (no enricher registered) + original_span = Mock(name="original_span") + original_span.context = Mock() + original_span.context.trace_id = 123 + original_span.context.span_id = 456 + + # Should not raise + processor.on_end(original_span) + + # Cleanup + processor.shutdown() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/observability/core/test_agent365.py b/tests/observability/core/test_agent365.py index 7ae5086e..358b9351 100644 --- a/tests/observability/core/test_agent365.py +++ b/tests/observability/core/test_agent365.py @@ -84,7 +84,7 @@ def test_configure_with_exporter_options_and_parameter_precedence(self, mock_is_ self.assertTrue(result, "configure() should return True with exporter_options") @patch("microsoft_agents_a365.observability.core.config._Agent365Exporter") - @patch("microsoft_agents_a365.observability.core.config.BatchSpanProcessor") + @patch("microsoft_agents_a365.observability.core.config._EnrichingBatchSpanProcessor") @patch("microsoft_agents_a365.observability.core.config.is_agent365_exporter_enabled") def test_batch_span_processor_and_exporter_called_with_correct_values( self, mock_is_enabled, mock_batch_processor, mock_exporter @@ -198,7 +198,7 @@ def test_configure_uses_existing_tracer_provider(self, mock_get_provider, mock_i # Verify types of processors processor_types = [type(p).__name__ for p in processors] - self.assertIn("BatchSpanProcessor", processor_types) + self.assertIn("_EnrichingBatchSpanProcessor", processor_types) self.assertIn("SpanProcessor", processor_types) diff --git a/tests/observability/extensions/semantickernel/test_span_enricher.py b/tests/observability/extensions/semantickernel/test_span_enricher.py new file mode 100644 index 00000000..c8817838 --- /dev/null +++ b/tests/observability/extensions/semantickernel/test_span_enricher.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for Semantic Kernel span enricher.""" + +import unittest +from unittest.mock import Mock + +from microsoft_agents_a365.observability.core.constants import ( + GEN_AI_INPUT_MESSAGES_KEY, + GEN_AI_OUTPUT_MESSAGES_KEY, +) +from microsoft_agents_a365.observability.extensions.semantickernel.span_enricher import ( + enrich_semantic_kernel_span, +) + + +class TestSemanticKernelSpanEnricher(unittest.TestCase): + """Test suite for enrich_semantic_kernel_span function.""" + + def test_invoke_agent_span_extracts_content_from_messages(self): + """Test that invoke_agent spans have content extracted from input/output messages.""" + # Create a mock span with invoke_agent name and message attributes + mock_span = Mock() + mock_span.name = "invoke_agent test-agent" + mock_span.attributes = { + "gen_ai.agent.invocation_input": '[{"role": "user", "content": "Hello"}]', + "gen_ai.agent.invocation_output": '[{"role": "assistant", "content": "Hi there!"}]', + } + + # Enrich the span + enriched = enrich_semantic_kernel_span(mock_span) + + # Verify it returns an EnrichedReadableSpan with extracted content + self.assertNotEqual(enriched, mock_span) + attributes = enriched.attributes + # extract_content_as_string_list returns a JSON string + self.assertEqual(attributes[GEN_AI_INPUT_MESSAGES_KEY], '["Hello"]') + self.assertEqual(attributes[GEN_AI_OUTPUT_MESSAGES_KEY], '["Hi there!"]') + + def test_non_matching_span_returns_original(self): + """Test that spans not matching invoke_agent or execute_tool are returned unchanged.""" + # Create a mock span with a different operation name + mock_span = Mock() + mock_span.name = "some_other_operation" + mock_span.attributes = { + "some_key": "some_value", + } + + # Enrich the span + result = enrich_semantic_kernel_span(mock_span) + + # Verify it returns the original span unchanged + self.assertEqual(result, mock_span) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tooling/extensions/agentframework/services/test_mcp_tool_registration_service.py b/tests/tooling/extensions/agentframework/services/test_mcp_tool_registration_service.py new file mode 100644 index 00000000..4282a5c7 --- /dev/null +++ b/tests/tooling/extensions/agentframework/services/test_mcp_tool_registration_service.py @@ -0,0 +1,727 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for add_tool_servers_to_agent method in McpToolRegistrationService. + +These tests verify that httpx.AsyncClient is correctly configured with +Authorization and User-Agent headers, and that MCPStreamableHTTPTool is +instantiated with the http_client parameter (not headers). + +This prevents regressions to the bug where passing headers directly to +MCPStreamableHTTPTool via **kwargs was silently ignored, causing 400 Bad Request +errors when calling MCP tool servers. +""" + +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from microsoft_agents_a365.tooling.extensions.agentframework.services import ( + McpToolRegistrationService, +) +from microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service import ( + MCP_HTTP_CLIENT_TIMEOUT_SECONDS, +) +from microsoft_agents_a365.tooling.utils.constants import Constants + + +class TestAddToolServersHttpxClientConfiguration: + """Tests for httpx.AsyncClient configuration in add_tool_servers_to_agent.""" + + @pytest.fixture + def mock_turn_context(self): + """Create a mock TurnContext.""" + mock_context = Mock() + mock_activity = Mock() + mock_conversation = Mock() + + mock_conversation.id = "conv-test-123" + mock_activity.conversation = mock_conversation + mock_activity.id = "msg-test-456" + + mock_context.activity = mock_activity + return mock_context + + @pytest.fixture + def mock_auth(self): + """Create a mock Authorization that returns a token on exchange.""" + mock_auth = AsyncMock() + mock_token_result = Mock() + mock_token_result.token = "test-auth-token-12345" + mock_auth.exchange_token = AsyncMock(return_value=mock_token_result) + return mock_auth + + @pytest.fixture + def mock_chat_client(self): + """Create a mock OpenAIChatClient or AzureOpenAIChatClient.""" + return Mock() + + @pytest.fixture + def mock_mcp_server_config(self): + """Create a mock MCP server configuration.""" + config = Mock() + config.mcp_server_name = "test-mcp-server" + config.mcp_server_unique_name = "test-mcp-server-unique" + config.url = "https://test-mcp-server.example.com/api" + return config + + @pytest.fixture + def service(self): + """Create McpToolRegistrationService instance.""" + return McpToolRegistrationService() + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_httpx_client_has_authorization_header( + self, + service, + mock_turn_context, + mock_auth, + mock_chat_client, + mock_mcp_server_config, + ): + """Test that httpx.AsyncClient is created with Authorization header.""" + auth_token = "test-bearer-token-xyz" + + with ( + patch.object( + service._mcp_server_configuration_service, + "list_tool_servers", + new_callable=AsyncMock, + return_value=[mock_mcp_server_config], + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.httpx.AsyncClient" + ) as mock_httpx_client, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.MCPStreamableHTTPTool" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.ChatAgent" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.resolve_agent_identity", + return_value="test-agent-id", + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.get_user_agent_header", + return_value="TestAgent/1.0", + ), + ): + mock_http_client_instance = MagicMock() + mock_httpx_client.return_value = mock_http_client_instance + + await service.add_tool_servers_to_agent( + chat_client=mock_chat_client, + agent_instructions="Test instructions", + initial_tools=[], + auth=mock_auth, + auth_handler_name="test-auth-handler", + turn_context=mock_turn_context, + auth_token=auth_token, + ) + + # Verify httpx.AsyncClient was called with headers containing Authorization + mock_httpx_client.assert_called_once() + call_kwargs = mock_httpx_client.call_args[1] + + assert "headers" in call_kwargs + expected_auth_header = f"{Constants.Headers.BEARER_PREFIX} {auth_token}" + assert call_kwargs["headers"][Constants.Headers.AUTHORIZATION] == expected_auth_header + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_httpx_client_has_user_agent_header( + self, + service, + mock_turn_context, + mock_auth, + mock_chat_client, + mock_mcp_server_config, + ): + """Test that httpx.AsyncClient is created with User-Agent header.""" + auth_token = "test-bearer-token-xyz" + expected_user_agent = "AgentFramework/1.0" + + with ( + patch.object( + service._mcp_server_configuration_service, + "list_tool_servers", + new_callable=AsyncMock, + return_value=[mock_mcp_server_config], + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.httpx.AsyncClient" + ) as mock_httpx_client, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.MCPStreamableHTTPTool" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.ChatAgent" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.resolve_agent_identity", + return_value="test-agent-id", + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.get_user_agent_header", + return_value=expected_user_agent, + ), + ): + mock_http_client_instance = MagicMock() + mock_httpx_client.return_value = mock_http_client_instance + + await service.add_tool_servers_to_agent( + chat_client=mock_chat_client, + agent_instructions="Test instructions", + initial_tools=[], + auth=mock_auth, + auth_handler_name="test-auth-handler", + turn_context=mock_turn_context, + auth_token=auth_token, + ) + + # Verify httpx.AsyncClient was called with User-Agent header + mock_httpx_client.assert_called_once() + call_kwargs = mock_httpx_client.call_args[1] + + assert "headers" in call_kwargs + assert call_kwargs["headers"][Constants.Headers.USER_AGENT] == expected_user_agent + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_httpx_client_has_correct_timeout( + self, + service, + mock_turn_context, + mock_auth, + mock_chat_client, + mock_mcp_server_config, + ): + """Test that httpx.AsyncClient is created with the defined timeout constant.""" + auth_token = "test-bearer-token-xyz" + + with ( + patch.object( + service._mcp_server_configuration_service, + "list_tool_servers", + new_callable=AsyncMock, + return_value=[mock_mcp_server_config], + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.httpx.AsyncClient" + ) as mock_httpx_client, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.MCPStreamableHTTPTool" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.ChatAgent" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.resolve_agent_identity", + return_value="test-agent-id", + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.get_user_agent_header", + return_value="TestAgent/1.0", + ), + ): + mock_http_client_instance = MagicMock() + mock_httpx_client.return_value = mock_http_client_instance + + await service.add_tool_servers_to_agent( + chat_client=mock_chat_client, + agent_instructions="Test instructions", + initial_tools=[], + auth=mock_auth, + auth_handler_name="test-auth-handler", + turn_context=mock_turn_context, + auth_token=auth_token, + ) + + # Verify httpx.AsyncClient was called with correct timeout + mock_httpx_client.assert_called_once() + call_kwargs = mock_httpx_client.call_args[1] + + assert "timeout" in call_kwargs + assert call_kwargs["timeout"] == MCP_HTTP_CLIENT_TIMEOUT_SECONDS + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_mcp_tool_receives_http_client_not_headers( + self, + service, + mock_turn_context, + mock_auth, + mock_chat_client, + mock_mcp_server_config, + ): + """Test that MCPStreamableHTTPTool is instantiated with http_client parameter. + + This is the critical test that prevents regression to the bug where + headers were passed directly to MCPStreamableHTTPTool and silently ignored. + The fix passes an httpx.AsyncClient with pre-configured headers instead. + """ + auth_token = "test-bearer-token-xyz" + + with ( + patch.object( + service._mcp_server_configuration_service, + "list_tool_servers", + new_callable=AsyncMock, + return_value=[mock_mcp_server_config], + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.httpx.AsyncClient" + ) as mock_httpx_client, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.MCPStreamableHTTPTool" + ) as mock_mcp_tool, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.ChatAgent" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.resolve_agent_identity", + return_value="test-agent-id", + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.get_user_agent_header", + return_value="TestAgent/1.0", + ), + ): + mock_http_client_instance = MagicMock() + mock_httpx_client.return_value = mock_http_client_instance + + await service.add_tool_servers_to_agent( + chat_client=mock_chat_client, + agent_instructions="Test instructions", + initial_tools=[], + auth=mock_auth, + auth_handler_name="test-auth-handler", + turn_context=mock_turn_context, + auth_token=auth_token, + ) + + # Verify MCPStreamableHTTPTool was called with http_client, NOT headers + mock_mcp_tool.assert_called_once() + call_kwargs = mock_mcp_tool.call_args[1] + + # Critical: http_client must be passed (this is the fix) + assert "http_client" in call_kwargs + assert call_kwargs["http_client"] is mock_http_client_instance + + # Critical: headers must NOT be passed directly (this was the bug) + assert "headers" not in call_kwargs + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_httpx_client_added_to_internal_list_for_cleanup( + self, + service, + mock_turn_context, + mock_auth, + mock_chat_client, + mock_mcp_server_config, + ): + """Test that created httpx clients are tracked in _http_clients for cleanup.""" + auth_token = "test-bearer-token-xyz" + + with ( + patch.object( + service._mcp_server_configuration_service, + "list_tool_servers", + new_callable=AsyncMock, + return_value=[mock_mcp_server_config], + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.httpx.AsyncClient" + ) as mock_httpx_client, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.MCPStreamableHTTPTool" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.ChatAgent" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.resolve_agent_identity", + return_value="test-agent-id", + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.get_user_agent_header", + return_value="TestAgent/1.0", + ), + ): + mock_http_client_instance = MagicMock() + mock_httpx_client.return_value = mock_http_client_instance + + # Clear any pre-existing clients + service._http_clients.clear() + + await service.add_tool_servers_to_agent( + chat_client=mock_chat_client, + agent_instructions="Test instructions", + initial_tools=[], + auth=mock_auth, + auth_handler_name="test-auth-handler", + turn_context=mock_turn_context, + auth_token=auth_token, + ) + + # Verify httpx client was added to internal tracking list + assert len(service._http_clients) == 1 + assert service._http_clients[0] is mock_http_client_instance + + +class TestMcpToolRegistrationServiceCleanup: + """Tests for cleanup method to ensure httpx clients are properly closed.""" + + @pytest.fixture + def service(self): + """Create McpToolRegistrationService instance.""" + return McpToolRegistrationService() + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_cleanup_closes_all_httpx_clients(self, service): + """Test that cleanup properly closes all tracked httpx clients.""" + # Create mock httpx clients + mock_client1 = AsyncMock() + mock_client2 = AsyncMock() + + service._http_clients = [mock_client1, mock_client2] + + await service.cleanup() + + # Verify both clients had aclose() called + mock_client1.aclose.assert_called_once() + mock_client2.aclose.assert_called_once() + + # Verify the list was cleared + assert len(service._http_clients) == 0 + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_cleanup_handles_client_close_errors_gracefully(self, service): + """Test that cleanup continues even if a client close raises an exception.""" + # Create mock clients - first one raises, second should still be closed + mock_client1 = AsyncMock() + mock_client1.aclose.side_effect = Exception("Connection error") + mock_client2 = AsyncMock() + + service._http_clients = [mock_client1, mock_client2] + + # Should not raise + await service.cleanup() + + # Both clients should have had aclose attempted + mock_client1.aclose.assert_called_once() + mock_client2.aclose.assert_called_once() + + # List should be cleared even after errors + assert len(service._http_clients) == 0 + + +class TestHttpxClientLifecycle: + """End-to-end tests for httpx client lifecycle management. + + These tests verify that clients created during add_tool_servers_to_agent() + are properly tracked and cleaned up by cleanup(), preventing connection + and file descriptor leaks. + """ + + @pytest.fixture + def mock_turn_context(self): + """Create a mock TurnContext.""" + mock_context = Mock() + mock_activity = Mock() + mock_conversation = Mock() + + mock_conversation.id = "conv-test-123" + mock_activity.conversation = mock_conversation + mock_activity.id = "msg-test-456" + + mock_context.activity = mock_activity + return mock_context + + @pytest.fixture + def mock_auth(self): + """Create a mock Authorization that returns a token on exchange.""" + mock_auth = AsyncMock() + mock_token_result = Mock() + mock_token_result.token = "test-auth-token-12345" + mock_auth.exchange_token = AsyncMock(return_value=mock_token_result) + return mock_auth + + @pytest.fixture + def mock_chat_client(self): + """Create a mock chat client.""" + return Mock() + + @pytest.fixture + def service(self): + """Create McpToolRegistrationService instance.""" + return McpToolRegistrationService() + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_full_client_lifecycle_single_server( + self, + service, + mock_turn_context, + mock_auth, + mock_chat_client, + ): + """Test full lifecycle: create client via add_tool_servers, then cleanup. + + This end-to-end test ensures that: + 1. add_tool_servers_to_agent() creates and tracks httpx clients + 2. cleanup() calls aclose() on each tracked client + 3. cleanup() clears the tracking list + """ + mock_server_config = Mock() + mock_server_config.mcp_server_name = "test-server" + mock_server_config.mcp_server_unique_name = "test-server-unique" + mock_server_config.url = "https://test.example.com/api" + + mock_http_client_instance = MagicMock() + + with ( + patch.object( + service._mcp_server_configuration_service, + "list_tool_servers", + new_callable=AsyncMock, + return_value=[mock_server_config], + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.httpx.AsyncClient" + ) as mock_httpx_client, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.MCPStreamableHTTPTool" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.ChatAgent" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.resolve_agent_identity", + return_value="test-agent-id", + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.get_user_agent_header", + return_value="TestAgent/1.0", + ), + ): + mock_httpx_client.return_value = mock_http_client_instance + + # Step 1: Create agent with tool servers - this should create and track httpx client + await service.add_tool_servers_to_agent( + chat_client=mock_chat_client, + agent_instructions="Test instructions", + initial_tools=[], + auth=mock_auth, + auth_handler_name="test-auth-handler", + turn_context=mock_turn_context, + auth_token="test-token", + ) + + # Verify client was tracked + assert len(service._http_clients) == 1 + assert service._http_clients[0] is mock_http_client_instance + + # Step 2: Call cleanup - this should close the client + await service.cleanup() + + # Verify aclose() was called on the client created during add_tool_servers + mock_http_client_instance.aclose.assert_called_once() + + # Verify tracking list was cleared + assert len(service._http_clients) == 0 + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_full_client_lifecycle_multiple_servers( + self, + service, + mock_turn_context, + mock_auth, + mock_chat_client, + ): + """Test lifecycle with multiple MCP servers creating multiple clients. + + Verifies that when multiple tool servers are configured, each gets its + own httpx client that is properly tracked and cleaned up. + """ + mock_server_config1 = Mock() + mock_server_config1.mcp_server_name = "server-1" + mock_server_config1.mcp_server_unique_name = "server-1-unique" + mock_server_config1.url = "https://server1.example.com/api" + + mock_server_config2 = Mock() + mock_server_config2.mcp_server_name = "server-2" + mock_server_config2.mcp_server_unique_name = "server-2-unique" + mock_server_config2.url = "https://server2.example.com/api" + + mock_server_config3 = Mock() + mock_server_config3.mcp_server_name = "server-3" + mock_server_config3.mcp_server_unique_name = "server-3-unique" + mock_server_config3.url = "https://server3.example.com/api" + + # Create unique mock clients for each server + mock_clients = [MagicMock() for _ in range(3)] + client_iter = iter(mock_clients) + + with ( + patch.object( + service._mcp_server_configuration_service, + "list_tool_servers", + new_callable=AsyncMock, + return_value=[mock_server_config1, mock_server_config2, mock_server_config3], + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.httpx.AsyncClient" + ) as mock_httpx_client, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.MCPStreamableHTTPTool" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.ChatAgent" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.resolve_agent_identity", + return_value="test-agent-id", + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.get_user_agent_header", + return_value="TestAgent/1.0", + ), + ): + # Return a different mock client for each call + mock_httpx_client.side_effect = lambda **kwargs: next(client_iter) + + # Step 1: Create agent with multiple tool servers + await service.add_tool_servers_to_agent( + chat_client=mock_chat_client, + agent_instructions="Test instructions", + initial_tools=[], + auth=mock_auth, + auth_handler_name="test-auth-handler", + turn_context=mock_turn_context, + auth_token="test-token", + ) + + # Verify all 3 clients were tracked + assert len(service._http_clients) == 3 + for i, client in enumerate(mock_clients): + assert service._http_clients[i] is client + + # Step 2: Call cleanup + await service.cleanup() + + # Verify aclose() was called on ALL clients + for client in mock_clients: + client.aclose.assert_called_once() + + # Verify tracking list was cleared + assert len(service._http_clients) == 0 + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_cleanup_idempotent_no_clients(self, service): + """Test that cleanup() is safe to call when no clients exist. + + Ensures cleanup doesn't raise errors when called on a fresh service + or called multiple times. + """ + # Verify initial state is empty + assert len(service._http_clients) == 0 + + # Should not raise when no clients to clean up + await service.cleanup() + + # Still empty + assert len(service._http_clients) == 0 + + # Safe to call again + await service.cleanup() + assert len(service._http_clients) == 0 + + @pytest.mark.asyncio + @pytest.mark.unit + async def test_cleanup_called_twice_after_creating_clients( + self, + service, + mock_turn_context, + mock_auth, + mock_chat_client, + ): + """Test that calling cleanup() twice doesn't cause issues. + + After the first cleanup, the list is cleared, so the second cleanup + should be a no-op without errors. + """ + mock_server_config = Mock() + mock_server_config.mcp_server_name = "test-server" + mock_server_config.mcp_server_unique_name = "test-server-unique" + mock_server_config.url = "https://test.example.com/api" + + mock_http_client_instance = MagicMock() + + with ( + patch.object( + service._mcp_server_configuration_service, + "list_tool_servers", + new_callable=AsyncMock, + return_value=[mock_server_config], + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.httpx.AsyncClient" + ) as mock_httpx_client, + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.MCPStreamableHTTPTool" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.ChatAgent" + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.resolve_agent_identity", + return_value="test-agent-id", + ), + patch( + "microsoft_agents_a365.tooling.extensions.agentframework.services.mcp_tool_registration_service.Utility.get_user_agent_header", + return_value="TestAgent/1.0", + ), + ): + mock_httpx_client.return_value = mock_http_client_instance + + await service.add_tool_servers_to_agent( + chat_client=mock_chat_client, + agent_instructions="Test instructions", + initial_tools=[], + auth=mock_auth, + auth_handler_name="test-auth-handler", + turn_context=mock_turn_context, + auth_token="test-token", + ) + + # First cleanup + await service.cleanup() + mock_http_client_instance.aclose.assert_called_once() + assert len(service._http_clients) == 0 + + # Second cleanup should be safe (no-op) + await service.cleanup() + # aclose still only called once (not twice) + mock_http_client_instance.aclose.assert_called_once() + assert len(service._http_clients) == 0 + + +class TestMcpHttpClientTimeoutConstant: + """Tests for the MCP_HTTP_CLIENT_TIMEOUT_SECONDS constant.""" + + @pytest.mark.unit + def test_timeout_constant_is_90_seconds(self): + """Verify the timeout constant has the expected value.""" + assert MCP_HTTP_CLIENT_TIMEOUT_SECONDS == 90.0 + + @pytest.mark.unit + def test_timeout_constant_is_float(self): + """Verify the timeout constant is a float for httpx compatibility.""" + assert isinstance(MCP_HTTP_CLIENT_TIMEOUT_SECONDS, float)