|
22 | 22 | from agents.run import get_default_agent_runner, set_default_agent_runner |
23 | 23 | from agents.tracing import get_trace_provider |
24 | 24 | from agents.tracing.provider import DefaultTraceProvider |
25 | | -from openai.types.responses import ResponsePromptParam |
26 | 25 |
|
27 | | -import temporalio.client |
28 | | -import temporalio.worker |
29 | | -from temporalio.client import ClientConfig |
30 | 26 | from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity |
31 | 27 | from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters |
32 | 28 | from temporalio.contrib.openai_agents._openai_runner import ( |
|
47 | 43 | DataConverter, |
48 | 44 | DefaultPayloadConverter, |
49 | 45 | ) |
50 | | -from temporalio.worker import ( |
51 | | - Replayer, |
52 | | - ReplayerConfig, |
53 | | - Worker, |
54 | | - WorkerConfig, |
55 | | - WorkflowReplayResult, |
56 | | -) |
| 46 | +from temporalio.plugin import SimplePlugin |
| 47 | +from temporalio.worker import WorkflowRunner |
57 | 48 | from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner |
58 | 49 |
|
59 | 50 | # Unsupported on python 3.9 |
@@ -172,7 +163,21 @@ def __init__(self) -> None: |
172 | 163 | super().__init__(ToJsonOptions(exclude_unset=True)) |
173 | 164 |
|
174 | 165 |
|
175 | | -class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): |
| 166 | +def _data_converter(converter: Optional[DataConverter]) -> DataConverter: |
| 167 | + if converter is None: |
| 168 | + return DataConverter(payload_converter_class=OpenAIPayloadConverter) |
| 169 | + elif converter.payload_converter_class is DefaultPayloadConverter: |
| 170 | + return dataclasses.replace( |
| 171 | + converter, payload_converter_class=OpenAIPayloadConverter |
| 172 | + ) |
| 173 | + elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): |
| 174 | + raise ValueError( |
| 175 | + "The payload converter must be of type OpenAIPayloadConverter." |
| 176 | + ) |
| 177 | + return converter |
| 178 | + |
| 179 | + |
| 180 | +class OpenAIAgentsPlugin(SimplePlugin): |
176 | 181 | """Temporal plugin for integrating OpenAI agents with Temporal workflows. |
177 | 182 |
|
178 | 183 | .. warning:: |
@@ -278,127 +283,48 @@ def __init__( |
278 | 283 | "When configuring a custom provider, the model activity must have start_to_close_timeout or schedule_to_close_timeout" |
279 | 284 | ) |
280 | 285 |
|
281 | | - self._model_params = model_params |
282 | | - self._model_provider = model_provider |
283 | | - self._mcp_server_providers = mcp_server_providers |
284 | | - self._register_activities = register_activities |
285 | | - |
286 | | - def init_client_plugin(self, next: temporalio.client.Plugin) -> None: |
287 | | - """Set the next client plugin""" |
288 | | - self.next_client_plugin = next |
289 | | - |
290 | | - async def connect_service_client( |
291 | | - self, config: temporalio.service.ConnectConfig |
292 | | - ) -> temporalio.service.ServiceClient: |
293 | | - """No modifications to service client""" |
294 | | - return await self.next_client_plugin.connect_service_client(config) |
295 | | - |
296 | | - def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: |
297 | | - """Set the next worker plugin""" |
298 | | - self.next_worker_plugin = next |
299 | | - |
300 | | - @staticmethod |
301 | | - def _data_converter(converter: Optional[DataConverter]) -> DataConverter: |
302 | | - if converter is None: |
303 | | - return DataConverter(payload_converter_class=OpenAIPayloadConverter) |
304 | | - elif converter.payload_converter_class is DefaultPayloadConverter: |
305 | | - return dataclasses.replace( |
306 | | - converter, payload_converter_class=OpenAIPayloadConverter |
307 | | - ) |
308 | | - elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): |
309 | | - raise ValueError( |
310 | | - "The payload converter must be of type OpenAIPayloadConverter." |
311 | | - ) |
312 | | - return converter |
313 | | - |
314 | | - def configure_client(self, config: ClientConfig) -> ClientConfig: |
315 | | - """Configure the Temporal client for OpenAI agents integration. |
316 | | -
|
317 | | - This method sets up the Pydantic data converter to enable proper |
318 | | - serialization of OpenAI agent objects and responses. |
319 | | -
|
320 | | - Args: |
321 | | - config: The client configuration to modify. |
322 | | -
|
323 | | - Returns: |
324 | | - The modified client configuration. |
325 | | - """ |
326 | | - config["data_converter"] = self._data_converter(config["data_converter"]) |
327 | | - return self.next_client_plugin.configure_client(config) |
328 | | - |
329 | | - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: |
330 | | - """Configure the Temporal worker for OpenAI agents integration. |
331 | | -
|
332 | | - This method adds the necessary interceptors and activities for OpenAI |
333 | | - agent execution: |
334 | | - - Adds tracing interceptors for OpenAI agent interactions |
335 | | - - Registers model execution activities |
336 | | -
|
337 | | - Args: |
338 | | - config: The worker configuration to modify. |
339 | | -
|
340 | | - Returns: |
341 | | - The modified worker configuration. |
342 | | - """ |
343 | | - config["interceptors"] = list(config.get("interceptors") or []) + [ |
344 | | - OpenAIAgentsTracingInterceptor() |
345 | | - ] |
| 286 | + # Delay activity construction until they are actually needed |
| 287 | + def add_activities( |
| 288 | + activities: Optional[Sequence[Callable]], |
| 289 | + ) -> Sequence[Callable]: |
| 290 | + if not register_activities: |
| 291 | + return activities or [] |
346 | 292 |
|
347 | | - if self._register_activities: |
348 | | - new_activities = [ModelActivity(self._model_provider).invoke_model_activity] |
| 293 | + new_activities = [ModelActivity(model_provider).invoke_model_activity] |
349 | 294 |
|
350 | | - server_names = [server.name for server in self._mcp_server_providers] |
| 295 | + server_names = [server.name for server in mcp_server_providers] |
351 | 296 | if len(server_names) != len(set(server_names)): |
352 | 297 | raise ValueError( |
353 | 298 | f"More than one mcp server registered with the same name. Please provide unique names." |
354 | 299 | ) |
355 | 300 |
|
356 | | - for mcp_server in self._mcp_server_providers: |
| 301 | + for mcp_server in mcp_server_providers: |
357 | 302 | new_activities.extend(mcp_server._get_activities()) |
358 | | - config["activities"] = list(config.get("activities") or []) + new_activities |
359 | | - |
360 | | - runner = config.get("workflow_runner") |
361 | | - if isinstance(runner, SandboxedWorkflowRunner): |
362 | | - config["workflow_runner"] = dataclasses.replace( |
363 | | - runner, |
364 | | - restrictions=runner.restrictions.with_passthrough_modules("mcp"), |
365 | | - ) |
366 | | - |
367 | | - config["workflow_failure_exception_types"] = list( |
368 | | - config.get("workflow_failure_exception_types") or [] |
369 | | - ) + [AgentsWorkflowError] |
370 | | - return self.next_worker_plugin.configure_worker(config) |
| 303 | + return list(activities or []) + new_activities |
371 | 304 |
|
372 | | - async def run_worker(self, worker: Worker) -> None: |
373 | | - """Run the worker with OpenAI agents temporal overrides. |
| 305 | + def workflow_runner(runner: Optional[WorkflowRunner]) -> WorkflowRunner: |
| 306 | + if not runner: |
| 307 | + raise ValueError("No WorkflowRunner provided to the OpenAI plugin.") |
374 | 308 |
|
375 | | - This method sets up the necessary runtime overrides for OpenAI agents |
376 | | - to work within the Temporal worker context, including custom runners |
377 | | - and trace providers. |
378 | | -
|
379 | | - Args: |
380 | | - worker: The worker instance to run. |
381 | | - """ |
382 | | - with set_open_ai_agent_temporal_overrides(self._model_params): |
383 | | - await self.next_worker_plugin.run_worker(worker) |
384 | | - |
385 | | - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: |
386 | | - """Configure the replayer for OpenAI Agents.""" |
387 | | - config["interceptors"] = list(config.get("interceptors") or []) + [ |
388 | | - OpenAIAgentsTracingInterceptor() |
389 | | - ] |
390 | | - config["data_converter"] = self._data_converter(config.get("data_converter")) |
391 | | - return self.next_worker_plugin.configure_replayer(config) |
392 | | - |
393 | | - @asynccontextmanager |
394 | | - async def run_replayer( |
395 | | - self, |
396 | | - replayer: Replayer, |
397 | | - histories: AsyncIterator[temporalio.client.WorkflowHistory], |
398 | | - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: |
399 | | - """Set the OpenAI Overrides during replay""" |
400 | | - with set_open_ai_agent_temporal_overrides(self._model_params): |
401 | | - async with self.next_worker_plugin.run_replayer( |
402 | | - replayer, histories |
403 | | - ) as results: |
404 | | - yield results |
| 309 | + # If in sandbox, add additional passthrough |
| 310 | + if isinstance(runner, SandboxedWorkflowRunner): |
| 311 | + return dataclasses.replace( |
| 312 | + runner, |
| 313 | + restrictions=runner.restrictions.with_passthrough_modules("mcp"), |
| 314 | + ) |
| 315 | + return runner |
| 316 | + |
| 317 | + @asynccontextmanager |
| 318 | + async def run_context() -> AsyncIterator[None]: |
| 319 | + with set_open_ai_agent_temporal_overrides(model_params): |
| 320 | + yield |
| 321 | + |
| 322 | + super().__init__( |
| 323 | + name="OpenAIAgentsPlugin", |
| 324 | + data_converter=_data_converter, |
| 325 | + worker_interceptors=[OpenAIAgentsTracingInterceptor()], |
| 326 | + activities=add_activities, |
| 327 | + workflow_runner=workflow_runner, |
| 328 | + workflow_failure_exception_types=[AgentsWorkflowError], |
| 329 | + run_context=lambda: run_context(), |
| 330 | + ) |
0 commit comments