Skip to content

Commit 9c3d2ae

Browse files
authored
Merge branch 'main' into fix/1188-converter-py313-isinstance
2 parents 6cc3490 + 5f24792 commit 9c3d2ae

File tree

12 files changed

+1544
-1389
lines changed

12 files changed

+1544
-1389
lines changed

temporalio/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,9 @@ class Client:
112112
Clients do not work across forks since runtimes do not work across forks.
113113
"""
114114

115-
@staticmethod
115+
@classmethod
116116
async def connect(
117+
cls,
117118
target_host: str,
118119
*,
119120
namespace: str = "default",
@@ -133,7 +134,7 @@ async def connect(
133134
runtime: Optional[temporalio.runtime.Runtime] = None,
134135
http_connect_proxy_config: Optional[HttpConnectProxyConfig] = None,
135136
header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC,
136-
) -> Client:
137+
) -> Self:
137138
"""Connect to a Temporal server.
138139
139140
Args:
@@ -209,7 +210,7 @@ def make_lambda(plugin, next):
209210

210211
service_client = await next_function(connect_config)
211212

212-
return Client(
213+
return cls(
213214
service_client,
214215
namespace=namespace,
215216
data_converter=data_converter,

temporalio/contrib/openai_agents/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,13 @@
1616
from temporalio.contrib.openai_agents._temporal_openai_agents import (
1717
OpenAIAgentsPlugin,
1818
OpenAIPayloadConverter,
19-
TestModel,
20-
TestModelProvider,
2119
)
2220
from temporalio.contrib.openai_agents._trace_interceptor import (
2321
OpenAIAgentsTracingInterceptor,
2422
)
2523
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
2624

27-
from . import workflow
25+
from . import testing, workflow
2826

2927
__all__ = [
3028
"AgentsWorkflowError",
@@ -33,7 +31,6 @@
3331
"OpenAIPayloadConverter",
3432
"StatelessMCPServerProvider",
3533
"StatefulMCPServerProvider",
36-
"TestModel",
37-
"TestModelProvider",
34+
"testing",
3835
"workflow",
3936
]

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 1 addition & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,7 @@
66
from datetime import timedelta
77
from typing import AsyncIterator, Callable, Optional, Sequence, Union
88

9-
from agents import (
10-
AgentOutputSchemaBase,
11-
Handoff,
12-
Model,
13-
ModelProvider,
14-
ModelResponse,
15-
ModelSettings,
16-
ModelTracing,
17-
Tool,
18-
TResponseInputItem,
19-
set_trace_provider,
20-
)
21-
from agents.items import TResponseStreamEvent
9+
from agents import ModelProvider, set_trace_provider
2210
from agents.run import get_default_agent_runner, set_default_agent_runner
2311
from agents.tracing import get_trace_provider
2412
from agents.tracing.provider import DefaultTraceProvider
@@ -97,58 +85,6 @@ def set_open_ai_agent_temporal_overrides(
9785
set_trace_provider(previous_trace_provider or DefaultTraceProvider())
9886

9987

100-
class TestModelProvider(ModelProvider):
101-
"""Test model provider which simply returns the given module."""
102-
103-
__test__ = False
104-
105-
def __init__(self, model: Model):
106-
"""Initialize a test model provider with a model."""
107-
self._model = model
108-
109-
def get_model(self, model_name: Union[str, None]) -> Model:
110-
"""Get a model from the model provider."""
111-
return self._model
112-
113-
114-
class TestModel(Model):
115-
"""Test model for use mocking model responses."""
116-
117-
__test__ = False
118-
119-
def __init__(self, fn: Callable[[], ModelResponse]) -> None:
120-
"""Initialize a test model with a callable."""
121-
self.fn = fn
122-
123-
async def get_response(
124-
self,
125-
system_instructions: Union[str, None],
126-
input: Union[str, list[TResponseInputItem]],
127-
model_settings: ModelSettings,
128-
tools: list[Tool],
129-
output_schema: Union[AgentOutputSchemaBase, None],
130-
handoffs: list[Handoff],
131-
tracing: ModelTracing,
132-
**kwargs,
133-
) -> ModelResponse:
134-
"""Get a response from the model."""
135-
return self.fn()
136-
137-
def stream_response(
138-
self,
139-
system_instructions: Optional[str],
140-
input: Union[str, list[TResponseInputItem]],
141-
model_settings: ModelSettings,
142-
tools: list[Tool],
143-
output_schema: Optional[AgentOutputSchemaBase],
144-
handoffs: list[Handoff],
145-
tracing: ModelTracing,
146-
**kwargs,
147-
) -> AsyncIterator[TResponseStreamEvent]:
148-
"""Get a streamed response from the model. Unimplemented."""
149-
raise NotImplementedError()
150-
151-
15288
class OpenAIPayloadConverter(PydanticPayloadConverter):
15389
"""PayloadConverter for OpenAI agents."""
15490

0 commit comments

Comments
 (0)