Skip to content

Commit 5e93e63

Browse files
authored
💥 Plugin Overhaul (#1139)
* Add static plugin constructor * Allow callables as parameters * Convert OpenAI plugin as example * Delay openai client creation * Change plugin structure to remove initializers * PR feedback - exposing SimplePlugin type * Early check OpenAI connectivity if outside a workflow * PR Feedback * Don't register activities in replayer * Delay activity construction until needed * Linting * Linting * Simplify statement
1 parent 5ca4861 commit 5e93e63

File tree

9 files changed

+513
-263
lines changed

9 files changed

+513
-263
lines changed

temporalio/client.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,13 @@
7070
WorkflowSerializationContext,
7171
)
7272
from temporalio.service import (
73+
ConnectConfig,
7374
HttpConnectProxyConfig,
7475
KeepAliveConfig,
7576
RetryConfig,
7677
RPCError,
7778
RPCStatusCode,
79+
ServiceClient,
7880
TLSConfig,
7981
)
8082

@@ -198,12 +200,14 @@ async def connect(
198200
http_connect_proxy_config=http_connect_proxy_config,
199201
)
200202

201-
root_plugin: Plugin = _RootPlugin()
203+
def make_lambda(plugin, next):
204+
return lambda config: plugin.connect_service_client(config, next)
205+
206+
next_function = ServiceClient.connect
202207
for plugin in reversed(plugins):
203-
plugin.init_client_plugin(root_plugin)
204-
root_plugin = plugin
208+
next_function = make_lambda(plugin, next_function)
205209

206-
service_client = await root_plugin.connect_service_client(connect_config)
210+
service_client = await next_function(connect_config)
207211

208212
return Client(
209213
service_client,
@@ -243,12 +247,10 @@ def __init__(
243247
plugins=plugins,
244248
)
245249

246-
root_plugin: Plugin = _RootPlugin()
247-
for plugin in reversed(plugins):
248-
plugin.init_client_plugin(root_plugin)
249-
root_plugin = plugin
250+
for plugin in plugins:
251+
config = plugin.configure_client(config)
250252

251-
self._init_from_config(root_plugin.configure_client(config))
253+
self._init_from_config(config)
252254

253255
def _init_from_config(self, config: ClientConfig):
254256
self._config = config
@@ -7541,20 +7543,6 @@ def name(self) -> str:
75417543
"""
75427544
return type(self).__module__ + "." + type(self).__qualname__
75437545

7544-
@abstractmethod
7545-
def init_client_plugin(self, next: Plugin) -> None:
7546-
"""Initialize this plugin in the plugin chain.
7547-
7548-
This method sets up the chain of responsibility pattern by providing a reference
7549-
to the next plugin in the chain. It is called during client creation to build
7550-
the plugin chain. Note, this may be called twice in the case of :py:meth:`connect`.
7551-
Implementations should store this reference and call the corresponding method
7552-
of the next plugin on method calls.
7553-
7554-
Args:
7555-
next: The next plugin in the chain to delegate to.
7556-
"""
7557-
75587546
@abstractmethod
75597547
def configure_client(self, config: ClientConfig) -> ClientConfig:
75607548
"""Hook called when creating a client to allow modification of configuration.
@@ -7572,8 +7560,10 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
75727560

75737561
@abstractmethod
75747562
async def connect_service_client(
7575-
self, config: temporalio.service.ConnectConfig
7576-
) -> temporalio.service.ServiceClient:
7563+
self,
7564+
config: ConnectConfig,
7565+
next: Callable[[ConnectConfig], Awaitable[ServiceClient]],
7566+
) -> ServiceClient:
75777567
"""Hook called when connecting to the Temporal service.
75787568
75797569
This method is called during service client connection and allows plugins
@@ -7586,16 +7576,3 @@ async def connect_service_client(
75867576
Returns:
75877577
The connected service client.
75887578
"""
7589-
7590-
7591-
class _RootPlugin(Plugin):
7592-
def init_client_plugin(self, next: Plugin) -> None:
7593-
raise NotImplementedError()
7594-
7595-
def configure_client(self, config: ClientConfig) -> ClientConfig:
7596-
return config
7597-
7598-
async def connect_service_client(
7599-
self, config: temporalio.service.ConnectConfig
7600-
) -> temporalio.service.ServiceClient:
7601-
return await temporalio.service.ServiceClient.connect(config)

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from pydantic_core import to_json
3737
from typing_extensions import Required, TypedDict
3838

39-
from temporalio import activity
39+
from temporalio import activity, workflow
4040
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
4141
from temporalio.exceptions import ApplicationError
4242

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 52 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,7 @@
2222
from agents.run import get_default_agent_runner, set_default_agent_runner
2323
from agents.tracing import get_trace_provider
2424
from agents.tracing.provider import DefaultTraceProvider
25-
from openai.types.responses import ResponsePromptParam
2625

27-
import temporalio.client
28-
import temporalio.worker
29-
from temporalio.client import ClientConfig
3026
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
3127
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
3228
from temporalio.contrib.openai_agents._openai_runner import (
@@ -47,13 +43,8 @@
4743
DataConverter,
4844
DefaultPayloadConverter,
4945
)
50-
from temporalio.worker import (
51-
Replayer,
52-
ReplayerConfig,
53-
Worker,
54-
WorkerConfig,
55-
WorkflowReplayResult,
56-
)
46+
from temporalio.plugin import SimplePlugin
47+
from temporalio.worker import WorkflowRunner
5748
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
5849

5950
# Unsupported on python 3.9
@@ -172,7 +163,21 @@ def __init__(self) -> None:
172163
super().__init__(ToJsonOptions(exclude_unset=True))
173164

174165

175-
class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
166+
def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
167+
if converter is None:
168+
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
169+
elif converter.payload_converter_class is DefaultPayloadConverter:
170+
return dataclasses.replace(
171+
converter, payload_converter_class=OpenAIPayloadConverter
172+
)
173+
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
174+
raise ValueError(
175+
"The payload converter must be of type OpenAIPayloadConverter."
176+
)
177+
return converter
178+
179+
180+
class OpenAIAgentsPlugin(SimplePlugin):
176181
"""Temporal plugin for integrating OpenAI agents with Temporal workflows.
177182
178183
.. warning::
@@ -278,127 +283,48 @@ def __init__(
278283
"When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout"
279284
)
280285

281-
self._model_params = model_params
282-
self._model_provider = model_provider
283-
self._mcp_server_providers = mcp_server_providers
284-
self._register_activities = register_activities
285-
286-
def init_client_plugin(self, next: temporalio.client.Plugin) -> None:
287-
"""Set the next client plugin"""
288-
self.next_client_plugin = next
289-
290-
async def connect_service_client(
291-
self, config: temporalio.service.ConnectConfig
292-
) -> temporalio.service.ServiceClient:
293-
"""No modifications to service client"""
294-
return await self.next_client_plugin.connect_service_client(config)
295-
296-
def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
297-
"""Set the next worker plugin"""
298-
self.next_worker_plugin = next
299-
300-
@staticmethod
301-
def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
302-
if converter is None:
303-
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
304-
elif converter.payload_converter_class is DefaultPayloadConverter:
305-
return dataclasses.replace(
306-
converter, payload_converter_class=OpenAIPayloadConverter
307-
)
308-
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
309-
raise ValueError(
310-
"The payload converter must be of type OpenAIPayloadConverter."
311-
)
312-
return converter
313-
314-
def configure_client(self, config: ClientConfig) -> ClientConfig:
315-
"""Configure the Temporal client for OpenAI agents integration.
316-
317-
This method sets up the Pydantic data converter to enable proper
318-
serialization of OpenAI agent objects and responses.
319-
320-
Args:
321-
config: The client configuration to modify.
322-
323-
Returns:
324-
The modified client configuration.
325-
"""
326-
config["data_converter"] = self._data_converter(config["data_converter"])
327-
return self.next_client_plugin.configure_client(config)
328-
329-
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
330-
"""Configure the Temporal worker for OpenAI agents integration.
331-
332-
This method adds the necessary interceptors and activities for OpenAI
333-
agent execution:
334-
- Adds tracing interceptors for OpenAI agent interactions
335-
- Registers model execution activities
336-
337-
Args:
338-
config: The worker configuration to modify.
339-
340-
Returns:
341-
The modified worker configuration.
342-
"""
343-
config["interceptors"] = list(config.get("interceptors") or []) + [
344-
OpenAIAgentsTracingInterceptor()
345-
]
286+
# Delay activity construction until they are actually needed
287+
def add_activities(
288+
activities: Optional[Sequence[Callable]],
289+
) -> Sequence[Callable]:
290+
if not register_activities:
291+
return activities or []
346292

347-
if self._register_activities:
348-
new_activities = [ModelActivity(self._model_provider).invoke_model_activity]
293+
new_activities = [ModelActivity(model_provider).invoke_model_activity]
349294

350-
server_names = [server.name for server in self._mcp_server_providers]
295+
server_names = [server.name for server in mcp_server_providers]
351296
if len(server_names) != len(set(server_names)):
352297
raise ValueError(
353298
f"More than one mcp server registered with the same name. Please provide unique names."
354299
)
355300

356-
for mcp_server in self._mcp_server_providers:
301+
for mcp_server in mcp_server_providers:
357302
new_activities.extend(mcp_server._get_activities())
358-
config["activities"] = list(config.get("activities") or []) + new_activities
359-
360-
runner = config.get("workflow_runner")
361-
if isinstance(runner, SandboxedWorkflowRunner):
362-
config["workflow_runner"] = dataclasses.replace(
363-
runner,
364-
restrictions=runner.restrictions.with_passthrough_modules("mcp"),
365-
)
366-
367-
config["workflow_failure_exception_types"] = list(
368-
config.get("workflow_failure_exception_types") or []
369-
) + [AgentsWorkflowError]
370-
return self.next_worker_plugin.configure_worker(config)
303+
return list(activities or []) + new_activities
371304

372-
async def run_worker(self, worker: Worker) -> None:
373-
"""Run the worker with OpenAI agents temporal overrides.
305+
def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner:
306+
if not runner:
307+
raise ValueError("No WorkflowRunner provided to the OpenAI plugin.")
374308

375-
This method sets up the necessary runtime overrides for OpenAI agents
376-
to work within the Temporal worker context, including custom runners
377-
and trace providers.
378-
379-
Args:
380-
worker: The worker instance to run.
381-
"""
382-
with set_open_ai_agent_temporal_overrides(self._model_params):
383-
await self.next_worker_plugin.run_worker(worker)
384-
385-
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
386-
"""Configure the replayer for OpenAI Agents."""
387-
config["interceptors"] = list(config.get("interceptors") or []) + [
388-
OpenAIAgentsTracingInterceptor()
389-
]
390-
config["data_converter"] = self._data_converter(config.get("data_converter"))
391-
return self.next_worker_plugin.configure_replayer(config)
392-
393-
@asynccontextmanager
394-
async def run_replayer(
395-
self,
396-
replayer: Replayer,
397-
histories: AsyncIterator[temporalio.client.WorkflowHistory],
398-
) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
399-
"""Set the OpenAI Overrides during replay"""
400-
with set_open_ai_agent_temporal_overrides(self._model_params):
401-
async with self.next_worker_plugin.run_replayer(
402-
replayer, histories
403-
) as results:
404-
yield results
309+
# If in sandbox, add additional passthrough
310+
if isinstance(runner, SandboxedWorkflowRunner):
311+
return dataclasses.replace(
312+
runner,
313+
restrictions=runner.restrictions.with_passthrough_modules("mcp"),
314+
)
315+
return runner
316+
317+
@asynccontextmanager
318+
async def run_context() -> AsyncIterator[None]:
319+
with set_open_ai_agent_temporal_overrides(model_params):
320+
yield
321+
322+
super().__init__(
323+
name="OpenAIAgentsPlugin",
324+
data_converter=_data_converter,
325+
worker_interceptors=[OpenAIAgentsTracingInterceptor()],
326+
activities=add_activities,
327+
workflow_runner=workflow_runner,
328+
workflow_failure_exception_types=[AgentsWorkflowError],
329+
run_context=lambda: run_context(),
330+
)

0 commit comments

Comments
 (0)