Skip to content

Commit 66d675f

Browse files
committed
fix: Support string enum for elicitation
1 parent 202af49 commit 66d675f

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/mcp/server/elicitation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import types
6+
from enum import Enum, StrEnum
67
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
78

89
from pydantic import BaseModel
@@ -37,7 +38,7 @@ class CancelledElicitation(BaseModel):
3738

3839

3940
# Primitive types allowed in elicitation schemas
40-
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
41+
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool, StrEnum)
4142

4243

4344
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
@@ -70,6 +71,10 @@ def _is_primitive_field(field_info: FieldInfo) -> bool:
7071
# All args must be primitive types or None
7172
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
7273

74+
# Handle Enum types
75+
if isinstance(annotation, type) and issubclass(annotation, str) and issubclass(annotation, Enum):
76+
return True
77+
7378
return False
7479

7580

tests/server/fastmcp/test_elicitation.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Test the elicitation feature using stdio transport.
33
"""
44

5+
from enum import StrEnum
56
from typing import Any
67

78
import pytest
@@ -142,6 +143,39 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par
142143
assert "Validation failed as expected" in result.content[0].text
143144
assert field_name in result.content[0].text
144145

146+
# Test valid Enum types (should not fail validation)
147+
class Status(StrEnum):
148+
ACTIVE = "active"
149+
INACTIVE = "inactive"
150+
151+
class ValidStrEnumSchema(BaseModel):
152+
status: Status = Field(description="Status using StrEnum")
153+
154+
def create_valid_validation_tool(name: str, schema_class: type[BaseModel]):
155+
@mcp.tool(name=name, description=f"Tool testing {name}")
156+
async def tool(ctx: Context[ServerSession, None]) -> str:
157+
# This should succeed without validation error
158+
result = await ctx.elicit(message="Test valid schema", schema=schema_class)
159+
return f"Success: {result.action}"
160+
161+
return tool
162+
163+
create_valid_validation_tool("valid_strenum", ValidStrEnumSchema)
164+
165+
async def enum_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams):
166+
# Return the required status field
167+
return ElicitResult(action="accept", content={"status": "active"})
168+
169+
async with create_connected_server_and_client_session(
170+
mcp._mcp_server, elicitation_callback=enum_callback
171+
) as client_session:
172+
await client_session.initialize()
173+
174+
result = await client_session.call_tool("valid_strenum", {})
175+
assert len(result.content) == 1
176+
assert isinstance(result.content[0], TextContent)
177+
assert "Success: accept" == result.content[0].text
178+
145179

146180
@pytest.mark.anyio
147181
async def test_elicitation_with_optional_fields():

0 commit comments

Comments
 (0)