Skip to content

Commit 57bfab2

Browse files
authored
Merge branch 'pydantic:main' into patch-1
2 parents 0c990ee + cedee4a commit 57bfab2

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

pydantic_ai_slim/pydantic_ai/ui/_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from abc import ABC, abstractmethod
44
from collections.abc import AsyncIterator, Sequence
5-
from dataclasses import KW_ONLY, Field, dataclass, replace
5+
from dataclasses import KW_ONLY, Field, dataclass
66
from functools import cached_property
77
from http import HTTPStatus
88
from typing import (
@@ -238,7 +238,7 @@ def run_stream_native(
238238
else:
239239
state = raw_state
240240

241-
deps = replace(deps, state=state)
241+
deps.state = state
242242
elif self.state:
243243
raise UserError(
244244
f'State is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'

pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from collections.abc import Callable, Mapping, Sequence
6+
from dataclasses import replace
67
from typing import Any, Generic
78

89
from typing_extensions import Self
@@ -18,7 +19,7 @@
1819
from pydantic_ai.toolsets import AbstractToolset
1920
from pydantic_ai.usage import RunUsage, UsageLimits
2021

21-
from .. import OnCompleteFunc
22+
from .. import OnCompleteFunc, StateHandler
2223
from ._adapter import AGUIAdapter
2324

2425
try:
@@ -121,6 +122,12 @@ def __init__(
121122

122123
async def run_agent(request: Request) -> Response:
123124
"""Endpoint to run the agent with the provided input data."""
125+
# `dispatch_request` will store the frontend state from the request on `deps.state` (if it implements the `StateHandler` protocol),
126+
# so we need to copy the deps to avoid different requests mutating the same deps object.
127+
nonlocal deps
128+
if isinstance(deps, StateHandler): # pragma: no branch
129+
deps = replace(deps)
130+
124131
return await AGUIAdapter[AgentDepsT, OutputDataT].dispatch_request(
125132
request,
126133
agent=agent,

tests/test_ag_ui.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,15 +1154,21 @@ async def store_state(
11541154
),
11551155
]
11561156

1157-
deps = StateDeps(StateInt(value=0))
1157+
seen_deps_states: list[int] = []
11581158

11591159
for run_input in run_inputs:
11601160
events = list[dict[str, Any]]()
1161-
async for event in run_ag_ui(agent, run_input, deps=deps):
1161+
deps = StateDeps(StateInt(value=0))
1162+
1163+
async def on_complete(result: AgentRunResult[Any]):
1164+
seen_deps_states.append(deps.state.value)
1165+
1166+
async for event in run_ag_ui(agent, run_input, deps=deps, on_complete=on_complete):
11621167
events.append(json.loads(event.removeprefix('data: ')))
11631168

11641169
assert events == simple_result()
11651170
assert seen_states == snapshot([41, 0, 0, 42])
1171+
assert seen_deps_states == snapshot([42, 1, 1, 43])
11661172

11671173

11681174
async def test_request_with_state_without_handler() -> None:
@@ -1275,8 +1281,10 @@ async def get_state(ctx: RunContext[StateDeps[StateInt]]) -> int:
12751281
async def test_to_ag_ui() -> None:
12761282
"""Test the agent.to_ag_ui method."""
12771283

1278-
agent = Agent(model=FunctionModel(stream_function=simple_stream))
1279-
app = agent.to_ag_ui()
1284+
agent = Agent(model=FunctionModel(stream_function=simple_stream), deps_type=StateDeps[StateInt])
1285+
1286+
deps = StateDeps(StateInt(value=0))
1287+
app = agent.to_ag_ui(deps=deps)
12801288
async with LifespanManager(app):
12811289
transport = httpx.ASGITransport(app)
12821290
async with httpx.AsyncClient(transport=transport) as client:
@@ -1286,6 +1294,7 @@ async def test_to_ag_ui() -> None:
12861294
id='msg_1',
12871295
content='Hello, world!',
12881296
),
1297+
state=StateInt(value=42),
12891298
)
12901299
async with client.stream(
12911300
'POST',
@@ -1301,6 +1310,9 @@ async def test_to_ag_ui() -> None:
13011310

13021311
assert events == simple_result()
13031312

1313+
# Verify the state was not mutated by the run
1314+
assert deps.state.value == 0
1315+
13041316

13051317
async def test_callback_sync() -> None:
13061318
"""Test that sync callbacks work correctly."""

0 commit comments

Comments
 (0)