diff --git a/README.md b/README.md
index 546c436f0..3f6a37ef1 100644
--- a/README.md
+++ b/README.md
@@ -1515,18 +1515,21 @@ import temporalio.service
 class AuthenticationPlugin(Plugin):
     def __init__(self, api_key: str):
         self.api_key = api_key
+
+    def init_client_plugin(self, next: Plugin) -> None:
+        self.next_client_plugin = next
 
     def configure_client(self, config: ClientConfig) -> ClientConfig:
         # Modify client configuration
         config["namespace"] = "my-secure-namespace"
-        return super().configure_client(config)
+        return self.next_client_plugin.configure_client(config)
 
     async def connect_service_client(
         self, config: temporalio.service.ConnectConfig
     ) -> temporalio.service.ServiceClient:
         # Add authentication to the connection
         config.api_key = self.api_key
-        return await super().connect_service_client(config)
+        return await self.next_client_plugin.connect_service_client(config)
 
 # Use the plugin when connecting
 client = await Client.connect(
@@ -1538,31 +1541,55 @@ client = await Client.connect(
 #### Worker Plugins
 
 Worker plugins can modify worker configuration and intercept worker execution. They are useful for adding monitoring,
-custom lifecycle management, or modifying worker settings.
+custom lifecycle management, or modifying worker settings. Worker plugins can also configure replay.
+They should do so when they modify the worker in a way that must also be present for replay to
+function, such as changing the data converter or adding workflows.
 
 Here's an example of a worker plugin that adds custom monitoring:
 
 ```python
-from temporalio.worker import Plugin, WorkerConfig, Worker
+import temporalio
+from contextlib import asynccontextmanager
+from typing import AsyncIterator
+from temporalio.worker import Plugin, WorkerConfig, ReplayerConfig, Worker, Replayer, WorkflowReplayResult
 import logging
 
 class MonitoringPlugin(Plugin):
     def __init__(self):
         self.logger = logging.getLogger(__name__)
 
+    def init_worker_plugin(self, next: Plugin) -> None:
+        self.next_worker_plugin = next
+
     def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
         # Modify worker configuration
         original_task_queue = config["task_queue"]
         config["task_queue"] = f"monitored-{original_task_queue}"
         self.logger.info(f"Worker created for task queue: {config['task_queue']}")
-        return super().configure_worker(config)
+        return self.next_worker_plugin.configure_worker(config)
 
     async def run_worker(self, worker: Worker) -> None:
         self.logger.info("Starting worker execution")
         try:
-            await super().run_worker(worker)
+            await self.next_worker_plugin.run_worker(worker)
         finally:
             self.logger.info("Worker execution completed")
+
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        return self.next_worker_plugin.configure_replayer(config)
+
+    @asynccontextmanager
+    async def run_replayer(
+        self,
+        replayer: Replayer,
+        histories: AsyncIterator[temporalio.client.WorkflowHistory],
+    ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
+        self.logger.info("Starting replay execution")
+        try:
+            async with self.next_worker_plugin.run_replayer(replayer, histories) as results:
+                yield results
+        finally:
+            self.logger.info("Replay execution completed")
 
 # Use the plugin when creating a worker
 worker = Worker(
@@ -1577,38 +1604,63 @@ worker = Worker(
 
 For plugins that need to work with both clients and workers, you can implement both interfaces in a single class:
 
 ```python
+import temporalio
+from contextlib import AbstractAsyncContextManager
+from typing import AsyncIterator
 from temporalio.client import Plugin as ClientPlugin, ClientConfig
-from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig
+from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig, ReplayerConfig, Worker, Replayer, WorkflowReplayResult
 
 class UnifiedPlugin(ClientPlugin, WorkerPlugin):
-    def configure_client(self, config: ClientConfig) -> ClientConfig:
-        # Client-side customization
-        config["namespace"] = "unified-namespace"
-        return super().configure_client(config)
-
-    def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
-        # Worker-side customization
-        config["max_cached_workflows"] = 500
-        return super().configure_worker(config)
-
-    async def run_worker(self, worker: Worker) -> None:
-        print("Starting unified worker")
-        await super().run_worker(worker)
-
+    def init_client_plugin(self, next: ClientPlugin) -> None:
+        self.next_client_plugin = next
+    def init_worker_plugin(self, next: WorkerPlugin) -> None:
+        self.next_worker_plugin = next
+
+    def configure_client(self, config: ClientConfig) -> ClientConfig:
+        # Client-side customization
+        config["data_converter"] = pydantic_data_converter
+        return self.next_client_plugin.configure_client(config)
+
+    async def connect_service_client(
+        self, config: temporalio.service.ConnectConfig
+    ) -> temporalio.service.ServiceClient:
+        # Add authentication to the connection
+        config.api_key = self.api_key
+        return await self.next_client_plugin.connect_service_client(config)
+
+    def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
+        # Worker-side customization
+        return self.next_worker_plugin.configure_worker(config)
+
+    async def run_worker(self, worker: Worker) -> None:
+        print("Starting unified worker")
+        await self.next_worker_plugin.run_worker(worker)
+
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        config["data_converter"] = pydantic_data_converter
+        return config
+
+    async def run_replayer(
+        self,
+        replayer: Replayer,
+        histories: AsyncIterator[temporalio.client.WorkflowHistory],
+    ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]:
+        return self.next_worker_plugin.run_replayer(replayer, histories)
+
 # Create client with the unified plugin
 client = await Client.connect(
-    "localhost:7233",
-    plugins=[UnifiedPlugin()]
+    "localhost:7233",
+    plugins=[UnifiedPlugin()]
 )
 
 # Worker will automatically inherit the plugin from the client
 worker = Worker(
-    client,
-    task_queue="my-task-queue",
-    workflows=[MyWorkflow],
-    activities=[my_activity]
+    client,
+    task_queue="my-task-queue",
+    workflows=[MyWorkflow],
+    activities=[my_activity]
 )
 ```
 
@@ -1617,7 +1669,7 @@ worker = Worker(
 
 - Plugins are executed in reverse order (last plugin wraps the first), forming a chain of responsibility
 - Client plugins that also implement worker plugin interfaces are automatically propagated to workers
 - Avoid providing the same plugin to both client and worker to prevent double execution
-- Plugin methods should call `super()` to maintain the plugin chain
+- Plugin methods should delegate to the next plugin provided during initialization to maintain the plugin chain
 - Each plugin's `name()` method returns a unique identifier for debugging purposes
 
diff --git a/temporalio/client.py b/temporalio/client.py
index b4b5453d7..6b206a912 100644
--- a/temporalio/client.py
+++ b/temporalio/client.py
@@ -191,7 +191,8 @@ async def connect(
 
         root_plugin: Plugin = _RootPlugin()
         for plugin in reversed(plugins):
-            root_plugin = plugin.init_client_plugin(root_plugin)
+            plugin.init_client_plugin(root_plugin)
+            root_plugin = plugin
 
         service_client = await root_plugin.connect_service_client(connect_config)
 
@@ -235,7 +236,8 @@ def __init__(
 
         root_plugin: Plugin = _RootPlugin()
         for plugin in reversed(plugins):
-            root_plugin = plugin.init_client_plugin(root_plugin)
+            plugin.init_client_plugin(root_plugin)
+            root_plugin = plugin
 
         self._init_from_config(root_plugin.configure_client(config))
 
@@ -7398,22 +7400,21 @@ def name(self) -> str:
         """
         return type(self).__module__ + "." + type(self).__qualname__
 
-    def init_client_plugin(self, next: Plugin) -> Plugin:
+    @abstractmethod
+    def init_client_plugin(self, next: Plugin) -> None:
         """Initialize this plugin in the plugin chain.
 
-        This method sets up the chain of responsibility pattern by storing a reference
+        This method sets up the chain of responsibility pattern by providing a reference
         to the next plugin in the chain. It is called during client creation to build
         the plugin chain. Note, this may be called twice in the case of :py:meth:`connect`.
+        Implementations should store this reference and call the corresponding method
+        of the next plugin on method calls.
 
         Args:
            next: The next plugin in the chain to delegate to.
-
-        Returns:
-            This plugin instance for method chaining.
         """
-        self.next_client_plugin = next
-        return self
 
+    @abstractmethod
     def configure_client(self, config: ClientConfig) -> ClientConfig:
         """Hook called when creating a client to allow modification of configuration.
 
@@ -7427,8 +7428,8 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
         Returns:
             The modified client configuration.
         """
-        return self.next_client_plugin.configure_client(config)
 
+    @abstractmethod
     async def connect_service_client(
         self, config: temporalio.service.ConnectConfig
     ) -> temporalio.service.ServiceClient:
@@ -7444,10 +7445,12 @@ async def connect_service_client(
 
         Returns:
             The connected service client.
""" - return await self.next_client_plugin.connect_service_client(config) class _RootPlugin(Plugin): + def init_client_plugin(self, next: Plugin) -> None: + raise NotImplementedError() + def configure_client(self, config: ClientConfig) -> ClientConfig: return config diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index a1f71db71..21defd4a8 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -1,6 +1,6 @@ """Initialize Temporal OpenAI Agents overrides.""" -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from datetime import timedelta from typing import AsyncIterator, Callable, Optional, Union @@ -24,7 +24,7 @@ import temporalio.client import temporalio.worker -from temporalio.client import ClientConfig +from temporalio.client import ClientConfig, Plugin from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner @@ -41,7 +41,13 @@ from temporalio.converter import ( DataConverter, ) -from temporalio.worker import Worker, WorkerConfig +from temporalio.worker import ( + Replayer, + ReplayerConfig, + Worker, + WorkerConfig, + WorkflowReplayResult, +) @contextmanager @@ -231,6 +237,20 @@ def __init__( self._model_params = model_params self._model_provider = model_provider + def init_client_plugin(self, next: temporalio.client.Plugin) -> None: + """Set the next client plugin""" + self.next_client_plugin = next + + async def connect_service_client( + self, config: temporalio.service.ConnectConfig + ) -> temporalio.service.ServiceClient: + """No modifications to service client""" + return await self.next_client_plugin.connect_service_client(config) + + def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: + """Set the next worker plugin""" + self.next_worker_plugin = next + def configure_client(self, config: ClientConfig) -> ClientConfig: """Configure the Temporal client for OpenAI agents integration. @@ -246,7 +266,7 @@ def configure_client(self, config: ClientConfig) -> ClientConfig: config["data_converter"] = DataConverter( payload_converter_class=_OpenAIPayloadConverter ) - return super().configure_client(config) + return self.next_client_plugin.configure_client(config) def configure_worker(self, config: WorkerConfig) -> WorkerConfig: """Configure the Temporal worker for OpenAI agents integration. @@ -268,7 +288,7 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["activities"] = list(config.get("activities") or []) + [ ModelActivity(self._model_provider).invoke_model_activity ] - return super().configure_worker(config) + return self.next_worker_plugin.configure_worker(config) async def run_worker(self, worker: Worker) -> None: """Run the worker with OpenAI agents temporal overrides. @@ -281,4 +301,27 @@ async def run_worker(self, worker: Worker) -> None: worker: The worker instance to run. 
""" with set_open_ai_agent_temporal_overrides(self._model_params): - await super().run_worker(worker) + await self.next_worker_plugin.run_worker(worker) + + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + """Configure the replayer for OpenAI Agents.""" + config["interceptors"] = list(config.get("interceptors") or []) + [ + OpenAIAgentsTracingInterceptor() + ] + config["data_converter"] = DataConverter( + payload_converter_class=_OpenAIPayloadConverter + ) + return config + + @asynccontextmanager + async def run_replayer( + self, + replayer: Replayer, + histories: AsyncIterator[temporalio.client.WorkflowHistory], + ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + """Set the OpenAI Overrides during replay""" + with set_open_ai_agent_temporal_overrides(self._model_params): + async with self.next_worker_plugin.run_replayer( + replayer, histories + ) as results: + yield results diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 6e062afcc..08686dcb3 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -21,6 +21,7 @@ WorkflowInterceptorClassInput, WorkflowOutboundInterceptor, ) +from ._plugin import Plugin from ._replayer import ( Replayer, ReplayerConfig, @@ -44,7 +45,6 @@ WorkflowSlotInfo, ) from ._worker import ( - Plugin, PollerBehavior, PollerBehaviorAutoscaling, PollerBehaviorSimpleMaximum, diff --git a/temporalio/worker/_plugin.py b/temporalio/worker/_plugin.py new file mode 100644 index 000000000..0e696a2dd --- /dev/null +++ b/temporalio/worker/_plugin.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import abc +from contextlib import AbstractAsyncContextManager +from typing import TYPE_CHECKING, AsyncIterator + +from temporalio.client import WorkflowHistory + +if TYPE_CHECKING: + from temporalio.worker import ( + Replayer, + ReplayerConfig, + Worker, + WorkerConfig, + WorkflowReplayResult, + ) + + +class Plugin(abc.ABC): + """Base class for worker plugins that can intercept and modify worker behavior. + + Plugins allow customization of worker creation and execution processes + through a chain of responsibility pattern. Each plugin can modify the worker + configuration or intercept worker execution. + + WARNING: This is an experimental feature and may change in the future. + """ + + def name(self) -> str: + """Get the qualified name of this plugin. Can be overridden if desired to provide a more appropriate name. + + Returns: + The fully qualified name of the plugin class (module.classname). + """ + return type(self).__module__ + "." + type(self).__qualname__ + + @abc.abstractmethod + def init_worker_plugin(self, next: Plugin) -> None: + """Initialize this plugin in the plugin chain. + + This method sets up the chain of responsibility pattern by providing a reference + to the next plugin in the chain. It is called during worker creation to build + the plugin chain. Implementations should store this reference and call the corresponding method + of the next plugin on method calls. + + Args: + next: The next plugin in the chain to delegate to. + """ + + @abc.abstractmethod + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + """Hook called when creating a worker to allow modification of configuration. + + This method is called during worker creation and allows plugins to modify + the worker configuration before the worker is fully initialized. Plugins + can modify task queue names, adjust concurrency settings, add interceptors, + or change other worker settings. 
+
+        Args:
+            config: The worker configuration dictionary to potentially modify.
+
+        Returns:
+            The modified worker configuration.
+        """
+
+    @abc.abstractmethod
+    async def run_worker(self, worker: Worker) -> None:
+        """Hook called when running a worker to allow interception of execution.
+
+        This method is called when the worker is started and allows plugins to
+        intercept or wrap the worker execution. Plugins can add monitoring,
+        custom lifecycle management, or other execution-time behavior.
+
+        Args:
+            worker: The worker instance to run.
+        """
+
+    @abc.abstractmethod
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        """Hook called when creating a replayer to allow modification of configuration.
+
+        This should be used to configure anything in ReplayerConfig needed to make execution match
+        the worker and client config. This could include interceptors, DataConverter, workflows, and more.
+
+        Args:
+            config: The replayer configuration dictionary to potentially modify.
+
+        Returns:
+            The modified replayer configuration.
+        """
+
+    @abc.abstractmethod
+    def run_replayer(
+        self,
+        replayer: Replayer,
+        histories: AsyncIterator[WorkflowHistory],
+    ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]:
+        """Hook called when running a replayer to allow interception of execution."""
+
+
+class _RootPlugin(Plugin):
+    def init_worker_plugin(self, next: Plugin) -> None:
+        raise NotImplementedError()
+
+    def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
+        return config
+
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        return config
+
+    async def run_worker(self, worker: Worker) -> None:
+        await worker._run()
+
+    def run_replayer(
+        self,
+        replayer: Replayer,
+        histories: AsyncIterator[WorkflowHistory],
+    ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]:
+        return replayer._workflow_replay_iterator(histories)
diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py
index 6e9761b58..240429bf7 100644
--- a/temporalio/worker/_replayer.py
+++ b/temporalio/worker/_replayer.py
@@ -5,7 +5,7 @@
 import asyncio
 import concurrent.futures
 import logging
-from contextlib import asynccontextmanager
+from contextlib import AbstractAsyncContextManager, asynccontextmanager
 from dataclasses import dataclass
 from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type
 
@@ -21,6 +21,7 @@
 
 from ..common import HeaderCodecBehavior
 from ._interceptor import Interceptor
+from ._plugin import _RootPlugin
 from ._worker import load_default_build_id
 from ._workflow import _WorkflowWorker
 from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
@@ -42,6 +43,7 @@ def __init__(
         namespace: str = "ReplayNamespace",
         data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
         interceptors: Sequence[Interceptor] = [],
+        plugins: Sequence[temporalio.worker.Plugin] = [],
         build_id: Optional[str] = None,
         identity: Optional[str] = None,
         workflow_failure_exception_types: Sequence[Type[BaseException]] = [],
@@ -62,8 +64,6 @@ def __init__(
                 will be shared across all replay calls and never explicitly shut down. Users are
                 encouraged to provide their own if needing more control.
""" - if not workflows: - raise ValueError("At least one workflow must be specified") self._config = ReplayerConfig( workflows=list(workflows), workflow_task_executor=( @@ -83,6 +83,18 @@ def __init__( header_codec_behavior=header_codec_behavior, ) + # Apply plugin configuration + root_plugin: temporalio.worker.Plugin = _RootPlugin() + for plugin in reversed(plugins): + plugin.init_worker_plugin(root_plugin) + root_plugin = plugin + self._config = root_plugin.configure_replayer(self._config) + self._plugin = root_plugin + + # Validate workflows after plugin configuration + if not self._config["workflows"]: + raise ValueError("At least one workflow must be specified") + def config(self) -> ReplayerConfig: """Config, as a dictionary, used to create this replayer. @@ -149,10 +161,9 @@ async def replay_workflows( replay_failures[result.history.run_id] = result.replay_failure return WorkflowReplayResults(replay_failures=replay_failures) - @asynccontextmanager - async def workflow_replay_iterator( + def workflow_replay_iterator( self, histories: AsyncIterator[temporalio.client.WorkflowHistory] - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: """Replay workflows for the given histories. This is a context manager for use via ``async with``. The value is an @@ -165,6 +176,12 @@ async def workflow_replay_iterator( An async iterator that returns replayed workflow results as they are replayed. """ + return self._plugin.run_replayer(self, histories) + + @asynccontextmanager + async def _workflow_replay_iterator( + self, histories: AsyncIterator[temporalio.client.WorkflowHistory] + ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: try: last_replay_failure: Optional[Exception] last_replay_complete = asyncio.Event() diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 58f881c04..f93848496 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -2,7 +2,6 @@ from __future__ import annotations -import abc import asyncio import concurrent.futures import hashlib @@ -39,6 +38,7 @@ from ._activity import SharedStateManager, _ActivityWorker from ._interceptor import Interceptor from ._nexus import _NexusWorker +from ._plugin import Plugin, _RootPlugin from ._tuning import WorkerTuner from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner @@ -89,75 +89,6 @@ def _to_bridge(self) -> temporalio.bridge.worker.PollerBehavior: ] -class Plugin(abc.ABC): - """Base class for worker plugins that can intercept and modify worker behavior. - - Plugins allow customization of worker creation and execution processes - through a chain of responsibility pattern. Each plugin can modify the worker - configuration or intercept worker execution. - """ - - def name(self) -> str: - """Get the qualified name of this plugin. Can be overridden if desired to provide a more appropriate name. - - Returns: - The fully qualified name of the plugin class (module.classname). - """ - return type(self).__module__ + "." + type(self).__qualname__ - - def init_worker_plugin(self, next: Plugin) -> Plugin: - """Initialize this plugin in the plugin chain. - - This method sets up the chain of responsibility pattern by storing a reference - to the next plugin in the chain. It is called during worker creation to build - the plugin chain. - - Args: - next: The next plugin in the chain to delegate to. - - Returns: - This plugin instance for method chaining. 
- """ - self.next_worker_plugin = next - return self - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - """Hook called when creating a worker to allow modification of configuration. - - This method is called during worker creation and allows plugins to modify - the worker configuration before the worker is fully initialized. Plugins - can modify task queue names, adjust concurrency settings, add interceptors, - or change other worker settings. - - Args: - config: The worker configuration dictionary to potentially modify. - - Returns: - The modified worker configuration. - """ - return self.next_worker_plugin.configure_worker(config) - - async def run_worker(self, worker: Worker) -> None: - """Hook called when running a worker to allow interception of execution. - - This method is called when the worker is started and allows plugins to - intercept or wrap the worker execution. Plugins can add monitoring, - custom lifecycle management, or other execution-time behavior. - - Args: - worker: The worker instance to run. - """ - await self.next_worker_plugin.run_worker(worker) - - -class _RootPlugin(Plugin): - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - return config - - async def run_worker(self, worker: Worker) -> None: - await worker._run() - - class Worker: """Worker to process workflows and/or activities. @@ -443,7 +374,8 @@ def __init__( root_plugin: Plugin = _RootPlugin() for plugin in reversed(plugins): - root_plugin = plugin.init_worker_plugin(root_plugin) + plugin.init_worker_plugin(root_plugin) + root_plugin = plugin config = root_plugin.configure_worker(config) self._plugin = root_plugin diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index d3ac92c5e..d625343b8 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -3,14 +3,7 @@ import pytest from temporalio.client import WorkflowHistory -from temporalio.contrib.openai_agents import ModelActivityParameters -from temporalio.contrib.openai_agents._temporal_openai_agents import ( - set_open_ai_agent_temporal_overrides, -) -from temporalio.contrib.openai_agents._trace_interceptor import ( - OpenAIAgentsTracingInterceptor, -) -from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin from temporalio.worker import Replayer from tests.contrib.openai_agents.test_openai import ( AgentsAsToolsWorkflow, @@ -39,17 +32,15 @@ async def test_replay(file_name: str) -> None: with (Path(__file__).with_name("histories") / file_name).open("r") as f: history_json = f.read() - with set_open_ai_agent_temporal_overrides(ModelActivityParameters()): - await Replayer( - workflows=[ - ResearchWorkflow, - ToolsWorkflow, - CustomerServiceWorkflow, - AgentsAsToolsWorkflow, - HelloWorldAgent, - InputGuardrailWorkflow, - OutputGuardrailWorkflow, - ], - data_converter=pydantic_data_converter, - interceptors=[OpenAIAgentsTracingInterceptor()], - ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) + await Replayer( + workflows=[ + ResearchWorkflow, + ToolsWorkflow, + CustomerServiceWorkflow, + AgentsAsToolsWorkflow, + HelloWorldAgent, + InputGuardrailWorkflow, + OutputGuardrailWorkflow, + ], + plugins=[OpenAIAgentsPlugin()], + ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 4a60bba4d..eb08bba2d 100644 
--- a/tests/test_plugins.py
+++ b/tests/test_plugins.py
@@ -1,15 +1,26 @@
 import dataclasses
+import uuid
 import warnings
-from typing import cast
+from contextlib import AbstractAsyncContextManager, asynccontextmanager
+from typing import AsyncIterator, cast
 
 import pytest
 
 import temporalio.client
 import temporalio.worker
-from temporalio.client import Client, ClientConfig, OutboundInterceptor
+from temporalio import workflow
+from temporalio.client import Client, ClientConfig, OutboundInterceptor, Plugin
+from temporalio.contrib.pydantic import pydantic_data_converter
 from temporalio.testing import WorkflowEnvironment
-from temporalio.worker import Worker, WorkerConfig
+from temporalio.worker import (
+    Replayer,
+    ReplayerConfig,
+    Worker,
+    WorkerConfig,
+    WorkflowReplayResult,
+)
 from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
+from tests.helpers import new_worker
 from tests.worker.test_worker import never_run_activity
 
 
@@ -26,18 +37,21 @@ class MyClientPlugin(temporalio.client.Plugin):
     def __init__(self):
         self.interceptor = TestClientInterceptor()
 
+    def init_client_plugin(self, next: Plugin) -> None:
+        self.next_client_plugin = next
+
     def configure_client(self, config: ClientConfig) -> ClientConfig:
         config["namespace"] = "replaced_namespace"
         config["interceptors"] = list(config.get("interceptors") or []) + [
             self.interceptor
         ]
-        return super().configure_client(config)
+        return self.next_client_plugin.configure_client(config)
 
     async def connect_service_client(
         self, config: temporalio.service.ConnectConfig
     ) -> temporalio.service.ServiceClient:
         config.api_key = "replaced key"
-        return await super().connect_service_client(config)
+        return await self.next_client_plugin.connect_service_client(config)
 
 
 async def test_client_plugin(client: Client, env: WorkflowEnvironment):
@@ -59,12 +73,42 @@ async def test_client_plugin(client: Client, env: WorkflowEnvironment):
 
 
 class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
+    def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
+        self.next_worker_plugin = next
+
+    def init_client_plugin(self, next: temporalio.client.Plugin) -> None:
+        self.next_client_plugin = next
+
+    def configure_client(self, config: ClientConfig) -> ClientConfig:
+        return self.next_client_plugin.configure_client(config)
+
     def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
         config["task_queue"] = "combined"
-        return super().configure_worker(config)
+        return self.next_worker_plugin.configure_worker(config)
+
+    async def connect_service_client(
+        self, config: temporalio.service.ConnectConfig
+    ) -> temporalio.service.ServiceClient:
+        return await self.next_client_plugin.connect_service_client(config)
+
+    async def run_worker(self, worker: Worker) -> None:
+        await self.next_worker_plugin.run_worker(worker)
+
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        return self.next_worker_plugin.configure_replayer(config)
+
+    def run_replayer(
+        self,
+        replayer: Replayer,
+        histories: AsyncIterator[temporalio.client.WorkflowHistory],
+    ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]:
+        return self.next_worker_plugin.run_replayer(replayer, histories)
 
 
 class MyWorkerPlugin(temporalio.worker.Plugin):
+    def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
+        self.next_worker_plugin = next
+
     def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
         config["task_queue"] = "replaced_queue"
         runner = config.get("workflow_runner")
@@ -73,10 +117,20 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
                 runner,
                 restrictions=runner.restrictions.with_passthrough_modules("my_module"),
             )
-        return super().configure_worker(config)
+        return self.next_worker_plugin.configure_worker(config)
 
     async def run_worker(self, worker: Worker) -> None:
-        await super().run_worker(worker)
+        await self.next_worker_plugin.run_worker(worker)
+
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        return self.next_worker_plugin.configure_replayer(config)
+
+    def run_replayer(
+        self,
+        replayer: Replayer,
+        histories: AsyncIterator[temporalio.client.WorkflowHistory],
+    ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]:
+        return self.next_worker_plugin.run_replayer(replayer, histories)
 
 
 async def test_worker_plugin_basic_config(client: Client) -> None:
@@ -136,3 +190,69 @@ async def test_worker_sandbox_restrictions(client: Client) -> None:
             SandboxedWorkflowRunner, worker.config().get("workflow_runner")
         ).restrictions.passthrough_modules
     )
+
+
+class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
+    def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
+        self.next_worker_plugin = next
+
+    def init_client_plugin(self, next: temporalio.client.Plugin) -> None:
+        self.next_client_plugin = next
+
+    def configure_client(self, config: ClientConfig) -> ClientConfig:
+        config["data_converter"] = pydantic_data_converter
+        return self.next_client_plugin.configure_client(config)
+
+    def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
+        config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow]
+        return self.next_worker_plugin.configure_worker(config)
+
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        config["data_converter"] = pydantic_data_converter
+        config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow]
+        return self.next_worker_plugin.configure_replayer(config)
+
+    async def run_worker(self, worker: Worker) -> None:
+        await self.next_worker_plugin.run_worker(worker)
+
+    async def connect_service_client(
+        self, config: temporalio.service.ConnectConfig
+    ) -> temporalio.service.ServiceClient:
+        return await self.next_client_plugin.connect_service_client(config)
+
+    @asynccontextmanager
+    async def run_replayer(
+        self,
+        replayer: Replayer,
+        histories: AsyncIterator[temporalio.client.WorkflowHistory],
+    ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
+        async with self.next_worker_plugin.run_replayer(replayer, histories) as result:
+            yield result
+
+
+@workflow.defn
+class HelloWorkflow:
+    @workflow.run
+    async def run(self, name: str) -> str:
+        return f"Hello, {name}!"
+
+
+async def test_replay(client: Client) -> None:
+    plugin = ReplayCheckPlugin()
+    new_config = client.config()
+    new_config["plugins"] = [plugin]
+    client = Client(**new_config)
+
+    async with new_worker(client) as worker:
+        handle = await client.start_workflow(
+            HelloWorkflow.run,
+            "Tim",
+            id=f"workflow-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+        )
+        await handle.result()
+        replayer = Replayer(workflows=[], plugins=[plugin])
+        assert len(replayer.config().get("workflows") or []) == 1
+        assert replayer.config().get("data_converter") == pydantic_data_converter
+
+        await replayer.replay_workflow(await handle.fetch_history())