Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def __init__(
mcp_server_providers: Sequence[
Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"]
] = (),
register_activities: bool = True,
) -> None:
"""Initialize the OpenAI agents plugin.

Expand All @@ -257,6 +258,9 @@ def __init__(
Each server will be wrapped in a TemporalMCPServer if not already wrapped,
and their activities will be automatically registered with the worker.
The plugin manages the connection lifecycle of these servers.
register_activities: Whether to register activities during the worker execution.
This can be disabled on some workers to allow a separation of workflows and activities
but should not be disabled on all workers, or agents will not be able to progress.
"""
if model_params is None:
model_params = ModelActivityParameters()
Expand All @@ -277,6 +281,7 @@ def __init__(
self._model_params = model_params
self._model_provider = model_provider
self._mcp_server_providers = mcp_server_providers
self._register_activities = register_activities

def init_client_plugin(self, next: temporalio.client.Plugin) -> None:
"""Set the next client plugin"""
Expand Down Expand Up @@ -338,17 +343,19 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
config["interceptors"] = list(config.get("interceptors") or []) + [
OpenAIAgentsTracingInterceptor()
]
new_activities = [ModelActivity(self._model_provider).invoke_model_activity]

server_names = [server.name for server in self._mcp_server_providers]
if len(server_names) != len(set(server_names)):
raise ValueError(
f"More than one mcp server registered with the same name. Please provide unique names."
)
if self._register_activities:
new_activities = [ModelActivity(self._model_provider).invoke_model_activity]

server_names = [server.name for server in self._mcp_server_providers]
if len(server_names) != len(set(server_names)):
raise ValueError(
f"More than one mcp server registered with the same name. Please provide unique names."
)

for mcp_server in self._mcp_server_providers:
new_activities.extend(mcp_server._get_activities())
config["activities"] = list(config.get("activities") or []) + new_activities
for mcp_server in self._mcp_server_providers:
new_activities.extend(mcp_server._get_activities())
config["activities"] = list(config.get("activities") or []) + new_activities

runner = config.get("workflow_runner")
if isinstance(runner, SandboxedWorkflowRunner):
Expand Down