|
2 | 2 | Test the elicitation feature using stdio transport. |
3 | 3 | """ |
4 | 4 |
|
| 5 | +from enum import StrEnum |
5 | 6 | from typing import Any |
6 | 7 |
|
7 | 8 | import pytest |
@@ -142,6 +143,39 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par |
142 | 143 | assert "Validation failed as expected" in result.content[0].text |
143 | 144 | assert field_name in result.content[0].text |
144 | 145 |
|
| 146 | + # Test valid Enum types (should not fail validation) |
| 147 | + class Status(StrEnum): |
| 148 | + ACTIVE = "active" |
| 149 | + INACTIVE = "inactive" |
| 150 | + |
| 151 | + class ValidStrEnumSchema(BaseModel): |
| 152 | + status: Status = Field(description="Status using StrEnum") |
| 153 | + |
| 154 | + def create_valid_validation_tool(name: str, schema_class: type[BaseModel]): |
| 155 | + @mcp.tool(name=name, description=f"Tool testing {name}") |
| 156 | + async def tool(ctx: Context[ServerSession, None]) -> str: |
| 157 | + # This should succeed without validation error |
| 158 | + result = await ctx.elicit(message="Test valid schema", schema=schema_class) |
| 159 | + return f"Success: {result.action}" |
| 160 | + |
| 161 | + return tool |
| 162 | + |
| 163 | + create_valid_validation_tool("valid_strenum", ValidStrEnumSchema) |
| 164 | + |
| 165 | + async def enum_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): |
| 166 | + # Return the required status field |
| 167 | + return ElicitResult(action="accept", content={"status": "active"}) |
| 168 | + |
| 169 | + async with create_connected_server_and_client_session( |
| 170 | + mcp._mcp_server, elicitation_callback=enum_callback |
| 171 | + ) as client_session: |
| 172 | + await client_session.initialize() |
| 173 | + |
| 174 | + result = await client_session.call_tool("valid_strenum", {}) |
| 175 | + assert len(result.content) == 1 |
| 176 | + assert isinstance(result.content[0], TextContent) |
| 177 | + assert "Success: accept" == result.content[0].text |
| 178 | + |
145 | 179 |
|
146 | 180 | @pytest.mark.anyio |
147 | 181 | async def test_elicitation_with_optional_fields(): |
|
0 commit comments