From b37aa26b2afa2469d0f65aa9b7226787588e61bc Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 21 Oct 2025 23:22:14 +0000 Subject: [PATCH] Bump temporalio and use SimplePlugin --- .../durable_exec/temporal/__init__.py | 161 ++++++------------ .../durable_exec/temporal/_logfire.py | 28 ++- pyproject.toml | 1 + uv.lock | 14 +- 4 files changed, 72 insertions(+), 132 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py index b982527345..b711157c89 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py @@ -1,24 +1,14 @@ from __future__ import annotations import warnings -from collections.abc import AsyncIterator, Callable, Sequence -from contextlib import AbstractAsyncContextManager from dataclasses import replace from typing import Any from pydantic.errors import PydanticUserError -from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter from temporalio.converter import DataConverter, DefaultPayloadConverter -from temporalio.service import ConnectConfig, ServiceClient -from temporalio.worker import ( - Plugin as WorkerPlugin, - Replayer, - ReplayerConfig, - Worker, - WorkerConfig, - WorkflowReplayResult, -) +from temporalio.plugin import SimplePlugin +from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from ...exceptions import UserError @@ -37,102 +27,61 @@ ] -class PydanticAIPlugin(ClientPlugin, WorkerPlugin): +def _data_converter(converter: DataConverter | None) -> DataConverter: + if converter and converter.payload_converter_class not in ( + DefaultPayloadConverter, + PydanticPayloadConverter, + ): + warnings.warn( # pragma: no cover + 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.' + ) + + return pydantic_data_converter + + +def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: + if not runner: + raise ValueError('No WorkflowRunner provided to the Pydantic AI plugin.') + + if isinstance(runner, SandboxedWorkflowRunner): + return replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + 'pydantic_ai', + 'pydantic', + 'pydantic_core', + 'logfire', + 'rich', + 'httpx', + 'anyio', + 'httpcore', + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize + 'attrs', + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize + 'numpy', + 'pandas', + ), + ) + return runner + + +class PydanticAIPlugin(SimplePlugin): """Temporal client and worker plugin for Pydantic AI.""" - def init_client_plugin(self, next: ClientPlugin) -> None: - self.next_client_plugin = next - - def init_worker_plugin(self, next: WorkerPlugin) -> None: - self.next_worker_plugin = next - - def configure_client(self, config: ClientConfig) -> ClientConfig: - config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) - return self.next_client_plugin.configure_client(config) - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] - if isinstance(runner, SandboxedWorkflowRunner): # pragma: no branch - config['workflow_runner'] = replace( - runner, - restrictions=runner.restrictions.with_passthrough_modules( - 'pydantic_ai', - 'pydantic', - 'pydantic_core', - 'logfire', - 'rich', - 'httpx', - 'anyio', - 'httpcore', - # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize - 'attrs', - # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize - 'numpy', - 'pandas', - ), - ) - - config['workflow_failure_exception_types'] = [ - *config.get('workflow_failure_exception_types', []), # pyright: ignore[reportUnknownMemberType] - UserError, - PydanticUserError, - ] - - return self.next_worker_plugin.configure_worker(config) - - async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: - return await self.next_client_plugin.connect_service_client(config) - - async def run_worker(self, worker: Worker) -> None: - await self.next_worker_plugin.run_worker(worker) - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover - config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) # pyright: ignore[reportUnknownMemberType] - return self.next_worker_plugin.configure_replayer(config) - - def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[WorkflowHistory], - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover - return self.next_worker_plugin.run_replayer(replayer, histories) - - def _get_new_data_converter(self, converter: DataConverter | None) -> DataConverter: - if converter and converter.payload_converter_class not in ( - DefaultPayloadConverter, - PydanticPayloadConverter, - ): - warnings.warn( # pragma: no cover - 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.' - ) - - return pydantic_data_converter - - -class AgentPlugin(WorkerPlugin): - """Temporal worker plugin for a specific Pydantic AI agent.""" - - def __init__(self, agent: TemporalAgent[Any, Any]): - self.agent = agent - - def init_worker_plugin(self, next: WorkerPlugin) -> None: - self.next_worker_plugin = next + def __init__(self): + super().__init__( # type: ignore[reportUnknownMemberType] + name='PydanticAIPlugin', + data_converter=_data_converter, + workflow_runner=_workflow_runner, + workflow_failure_exception_types=[UserError, PydanticUserError], + ) - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] - # Activities are checked for name conflicts by Temporal. - config['activities'] = [*activities, *self.agent.temporal_activities] - return self.next_worker_plugin.configure_worker(config) - async def run_worker(self, worker: Worker) -> None: - await self.next_worker_plugin.run_worker(worker) - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover - return self.next_worker_plugin.configure_replayer(config) +class AgentPlugin(SimplePlugin): + """Temporal worker plugin for a specific Pydantic AI agent.""" - def run_replayer( - self, - replayer: Replayer, - histories: AsyncIterator[WorkflowHistory], - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover - return self.next_worker_plugin.run_replayer(replayer, histories) + def __init__(self, agent: TemporalAgent[Any, Any]): + super().__init__( # type: ignore[reportUnknownMemberType] + name='AgentPlugin', + activities=agent.temporal_activities, + ) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py index 055567e4d6..1c1c8f8a08 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py @@ -1,9 +1,9 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING -from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.plugin import SimplePlugin from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig from temporalio.service import ConnectConfig, ServiceClient @@ -19,12 +19,14 @@ def _default_setup_logfire() -> Logfire: return instance -class LogfirePlugin(ClientPlugin): +class LogfirePlugin(SimplePlugin): """Temporal client plugin for Logfire.""" def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire, *, metrics: bool = True): try: import logfire # noqa: F401 # pyright: ignore[reportUnusedImport] + from opentelemetry.trace import get_tracer + from temporalio.contrib.opentelemetry import TracingInterceptor except ImportError as _import_error: raise ImportError( 'Please install the `logfire` package to use the Logfire plugin, ' @@ -34,18 +36,14 @@ def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire self.setup_logfire = setup_logfire self.metrics = metrics - def init_client_plugin(self, next: ClientPlugin) -> None: - self.next_client_plugin = next + super().__init__( # type: ignore[reportUnknownMemberType] + name='LogfirePlugin', + client_interceptors=[TracingInterceptor(get_tracer('temporalio'))], + ) - def configure_client(self, config: ClientConfig) -> ClientConfig: - from opentelemetry.trace import get_tracer - from temporalio.contrib.opentelemetry import TracingInterceptor - - interceptors = config.get('interceptors', []) - config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))] - return self.next_client_plugin.configure_client(config) - - async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: + async def connect_service_client( + self, config: ConnectConfig, next: Callable[[ConnectConfig], Awaitable[ServiceClient]] + ) -> ServiceClient: logfire = self.setup_logfire() if self.metrics: @@ -60,4 +58,4 @@ async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers)) ) - return await self.next_client_plugin.connect_service_client(config) + return await next(config) diff --git a/pyproject.toml b/pyproject.toml index b09a172045..fd292d1d6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ pydantic-ai-slim = { workspace = true } pydantic-evals = { workspace = true } pydantic-graph = { workspace = true } pydantic-ai-examples = { workspace = true } +temporalio = { git = "https://github.com/temporalio/sdk-python.git", rev = "main" } [tool.uv.workspace] members = [ diff --git a/uv.lock b/uv.lock index 486ea9bbd3..0075661bde 100644 --- a/uv.lock +++ b/uv.lock @@ -3921,7 +3921,7 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, - { name = "temporalio", marker = "extra == 'temporal'", specifier = "==1.18.0" }, + { name = "temporalio", marker = "extra == 'temporal'", git = "https://github.com/temporalio/sdk-python.git?rev=main" }, { name = "tenacity", marker = "extra == 'retries'", specifier = ">=8.2.3" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] @@ -4938,8 +4938,8 @@ wheels = [ [[package]] name = "temporalio" -version = "1.18.0" -source = { registry = "https://pypi.org/simple" } +version = "1.18.1" +source = { git = "https://github.com/temporalio/sdk-python.git?rev=main#f03ddc2136c56acda77387627f9b379166d06d03" } dependencies = [ { name = "nexus-rpc" }, { name = "protobuf" }, @@ -4947,14 +4947,6 @@ dependencies = [ { name = "types-protobuf" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7e/20/b52c96b37bf00ead6e8a4a197075770ebad516db765cc3abca8396de0689/temporalio-1.18.0.tar.gz", hash = "sha256:7ff7f833eb1e7697084b4ed9d86c3167cbff1ec77f1b40df774313a5d0fd5f6d", size = 1781572, upload-time = "2025-09-19T23:40:52.511Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/28/c5a4ee259748450ac0765837f8c78cbfa36800264158d98bd2cde4496d87/temporalio-1.18.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:ac5d30d8b010c9b042065ea1259da7638db1a0a25e81ee4be0671a393ed329c5", size = 12734753, upload-time = "2025-09-19T23:40:06.575Z" }, - { url = "https://files.pythonhosted.org/packages/be/94/24bd903b5594420a4d131bfa3de965313f9a409af77b47e9a9a56d85bb9e/temporalio-1.18.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:19315d192247230c9bd7c60a566c2b3a80ad4d9de891c6aa13df63d72d3ec169", size = 12323141, upload-time = "2025-09-19T23:40:16.817Z" }, - { url = "https://files.pythonhosted.org/packages/6d/76/82415b43c68e2c6bb3a85e8800555d206767815088c8cad0ade9a06bd7ac/temporalio-1.18.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a023b25033e48b2e43f623a78737047a45b8cb553f69f457d09272fce5c723da", size = 12694061, upload-time = "2025-09-19T23:40:26.388Z" }, - { url = "https://files.pythonhosted.org/packages/41/60/176a3224c2739fee270052dd9224ae36370c4e13d2ab1bb96a2f9bbb513c/temporalio-1.18.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:695211dddbcffc20077d5b3b9a9b41bd09f60393c4ff211bcc7d6d895d607cc1", size = 12879404, upload-time = "2025-09-19T23:40:37.487Z" }, - { url = "https://files.pythonhosted.org/packages/e3/8d/e3809b356262d1d398d8cbb78df1e19d460c0a89e6ab64ca8d9c05d5fe5a/temporalio-1.18.0-cp39-abi3-win_amd64.whl", hash = "sha256:e3f691bd0a01a22c0fe40e87b6236cc8a292628e3a5a490880d1bf94709249c9", size = 13088041, upload-time = "2025-09-19T23:40:49.469Z" }, -] [[package]] name = "tenacity"