diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index bc16d9edb9..fb0dd2c7e0 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -886,6 +886,8 @@ def _parse_annotation(annotation: Any) -> Any: If the second annotation (after the type) is a string, then we convert that to a Pydantic Field description. The rest are returned as-is, allowing for multiple annotations. + Literal types are returned as-is to preserve their enum-like values. + Args: annotation: The type annotation to parse. @@ -894,6 +896,12 @@ def _parse_annotation(annotation: Any) -> Any: """ origin = get_origin(annotation) if origin is not None: + # Literal types should be returned as-is - their args are the allowed values, + # not type annotations to be parsed. For example, Literal["Data", "Security"] + # has args ("Data", "Security") which are the valid string values. + if origin is Literal: + return annotation + args = get_args(annotation) # For other generics, return the origin type (e.g., list for List[int]) if len(args) > 1 and isinstance(args[1], str): diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 4beee1fb7d..88c34dc3e8 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any +from typing import Annotated, Any, Literal from unittest.mock import Mock import pytest @@ -14,7 +14,7 @@ ToolProtocol, ai_function, ) -from agent_framework._tools import _parse_inputs +from agent_framework._tools import _parse_annotation, _parse_inputs from agent_framework.exceptions import ToolException from agent_framework.observability import OtelAttr @@ -128,6 +128,95 @@ def test_tool(self, x: int, y: int) -> int: assert test_tool(1, 2) == 3 +def test_ai_function_with_literal_type_parameter(): + """Test ai_function decorator with Literal type parameter (issue #2891).""" + + @ai_function + def search_flows(category: Literal["Data", "Security", "Network"], issue: str) -> str: + """Search flows by category.""" + return f"{category}: {issue}" + + assert isinstance(search_flows, AIFunction) + schema = search_flows.parameters() + assert schema == { + "properties": { + "category": {"enum": ["Data", "Security", "Network"], "title": "Category", "type": "string"}, + "issue": {"title": "Issue", "type": "string"}, + }, + "required": ["category", "issue"], + "title": "search_flows_input", + "type": "object", + } + # Verify invocation works + assert search_flows("Data", "test issue") == "Data: test issue" + + +def test_ai_function_with_literal_type_in_class_method(): + """Test ai_function decorator with Literal type parameter in a class method (issue #2891).""" + + class MyTools: + @ai_function + def search_flows(self, category: Literal["Data", "Security", "Network"], issue: str) -> str: + """Search flows by category.""" + return f"{category}: {issue}" + + tools = MyTools() + search_tool = tools.search_flows + assert isinstance(search_tool, AIFunction) + schema = search_tool.parameters() + assert schema == { + "properties": { + "category": {"enum": ["Data", "Security", "Network"], "title": "Category", "type": "string"}, + "issue": {"title": "Issue", "type": "string"}, + }, + "required": ["category", "issue"], + "title": "search_flows_input", + "type": "object", + } + # Verify invocation works + assert search_tool("Security", "test issue") == "Security: test issue" + + +def test_ai_function_with_literal_int_type(): + """Test ai_function decorator with Literal int type parameter.""" + + @ai_function + def set_priority(priority: Literal[1, 2, 3], task: str) -> str: + """Set priority for a task.""" + return f"Priority {priority}: {task}" + + assert isinstance(set_priority, AIFunction) + schema = set_priority.parameters() + assert schema == { + "properties": { + "priority": {"enum": [1, 2, 3], "title": "Priority", "type": "integer"}, + "task": {"title": "Task", "type": "string"}, + }, + "required": ["priority", "task"], + "title": "set_priority_input", + "type": "object", + } + assert set_priority(1, "important task") == "Priority 1: important task" + + +def test_ai_function_with_literal_and_annotated(): + """Test ai_function decorator with Literal type combined with Annotated for description.""" + + @ai_function + def categorize( + category: Annotated[Literal["A", "B", "C"], "The category to assign"], + name: str, + ) -> str: + """Categorize an item.""" + return f"{category}: {name}" + + assert isinstance(categorize, AIFunction) + schema = categorize.parameters() + # Literal type inside Annotated should preserve enum values + assert schema["properties"]["category"]["enum"] == ["A", "B", "C"] + assert categorize("A", "test") == "A: test" + + async def test_ai_function_decorator_shared_state(): """Test that decorated methods maintain shared state across multiple calls and tool usage.""" @@ -1368,3 +1457,70 @@ def tool_with_kwargs(x: int, **kwargs: Any) -> str: arguments=tool_with_kwargs.input_model(x=10), ) assert result_default == "x=10, user=unknown" + + +# region _parse_annotation tests + + +def test_parse_annotation_with_literal_type(): + """Test that _parse_annotation returns Literal types unchanged (issue #2891).""" + from typing import get_args, get_origin + + # Literal with string values + literal_annotation = Literal["Data", "Security", "Network"] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == ("Data", "Security", "Network") + + +def test_parse_annotation_with_literal_int_type(): + """Test that _parse_annotation returns Literal int types unchanged.""" + from typing import get_args, get_origin + + literal_annotation = Literal[1, 2, 3] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == (1, 2, 3) + + +def test_parse_annotation_with_literal_bool_type(): + """Test that _parse_annotation returns Literal bool types unchanged.""" + from typing import get_args, get_origin + + literal_annotation = Literal[True, False] + result = _parse_annotation(literal_annotation) + assert result is literal_annotation + assert get_origin(result) is Literal + assert get_args(result) == (True, False) + + +def test_parse_annotation_with_simple_types(): + """Test that _parse_annotation returns simple types unchanged.""" + assert _parse_annotation(str) is str + assert _parse_annotation(int) is int + assert _parse_annotation(float) is float + assert _parse_annotation(bool) is bool + + +def test_parse_annotation_with_annotated_and_literal(): + """Test that Annotated[Literal[...], description] works correctly.""" + from typing import get_args, get_origin + + # When Literal is inside Annotated, it should still be preserved + annotated_literal = Annotated[Literal["A", "B", "C"], "The category"] + result = _parse_annotation(annotated_literal) + + # The Annotated type should be preserved + origin = get_origin(result) + assert origin is Annotated + + args = get_args(result) + # First arg is the Literal type + literal_type = args[0] + assert get_origin(literal_type) is Literal + assert get_args(literal_type) == ("A", "B", "C") + + +# endregion