Commit c252f05

Add an optional ext module for openai (#162)
1 parent 2c61ecf commit c252f05

4 files changed: +603, -83 lines

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ harness = ["testcontainers", "hypercorn", "httpx"]
 serde = ["dacite", "pydantic", "msgspec"]
 client = ["httpx[http2]"]
 adk = ["google-adk>=1.20.0"]
+openai = ["openai-agents>=0.6.1"]

 [build-system]
 requires = ["maturin>=1.6,<2.0"]
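
With this extra declared, the integration can presumably be installed via the package extra, for example `pip install "restate_sdk[openai]"` (assuming the SDK keeps its published package name `restate_sdk`), which pulls in `openai-agents>=0.6.1`.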
Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@

#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
This module contains the optional OpenAI integration for Restate.
"""

from .runner_wrapper import Runner, DurableModelCalls, continue_on_terminal_errors, raise_terminal_errors

__all__ = [
    "DurableModelCalls",
    "continue_on_terminal_errors",
    "raise_terminal_errors",
    "Runner",
]
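
For orientation, here is a minimal usage sketch of the exported `Runner`. The import path (`restate.ext.openai`), the service name, and the handler are assumptions for illustration; the diff shows the module contents but not where it is mounted. The sketch also assumes the Restate Python service API (`restate.Service`, `restate.app`) and the `agents.Agent` class.

# Hypothetical usage sketch: the import path restate.ext.openai is an assumption,
# since the new files' locations are not visible in this diff.
import restate
from agents import Agent
from restate.ext.openai import Runner  # assumed module path

assistant = Agent(
    name="Assistant",
    instructions="Answer concisely.",
)

agent_service = restate.Service("Assistant")

@agent_service.handler()
async def ask(ctx: restate.Context, prompt: str) -> str:
    # Runner.run swaps in DurableModelCalls, so the integration picks up the
    # current Restate context and journals every LLM call via ctx.run_typed.
    result = await Runner.run(assistant, input=prompt)
    return result.final_output

app = restate.app(services=[agent_service])

Because `Runner.run` replaces the `model_provider` with `DurableModelCalls`, each model call goes through `ctx.run_typed("call LLM", ...)` and is replayed from the journal on retries, while function tools are wrapped in their own `ctx.run_typed` steps and serialized by the per-agent lock.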
Lines changed: 283 additions & 0 deletions

@@ -0,0 +1,283 @@

#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
This module contains the optional OpenAI integration for Restate.
"""

import asyncio
import dataclasses
import typing

from agents import (
    Tool,
    Usage,
    Model,
    RunContextWrapper,
    AgentsException,
    Runner as OpenAIRunner,
    RunConfig,
    TContext,
    RunResult,
    Agent,
    ModelBehaviorError,
)

from agents.models.multi_provider import MultiProvider
from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
from agents.memory.session import SessionABC
from agents.items import TResponseInputItem
from typing import List, Any
from typing import AsyncIterator

from agents.tool import FunctionTool
from agents.tool_context import ToolContext
from pydantic import BaseModel
from restate.exceptions import SdkInternalBaseException
from restate.extensions import current_context

from restate import RunOptions, ObjectContext, TerminalError


# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
# The Restate SDK cannot serialize this, so we turn the ModelResponse into a Pydantic model.
class RestateModelResponse(BaseModel):
    output: list[TResponseOutputItem]
    """A list of outputs (messages, tool calls, etc.) generated by the model."""

    usage: Usage
    """The usage information for the response."""

    response_id: str | None
    """An ID for the response which can be used to refer to the response in subsequent calls to the
    model. Not supported by all model providers.
    If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can
    be passed to `Runner.run`.
    """

    def to_input_items(self) -> list[TResponseInputItem]:
        return [it.model_dump(exclude_unset=True) for it in self.output]  # type: ignore


class DurableModelCalls(MultiProvider):
    """
    A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
    """

    def __init__(self, max_retries: int | None = 3):
        super().__init__()
        self.max_retries = max_retries

    def get_model(self, model_name: str | None) -> Model:
        return RestateModelWrapper(super().get_model(model_name or None), self.max_retries)


class RestateModelWrapper(Model):
    """
    A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
    """

    def __init__(self, model: Model, max_retries: int | None = 3):
        self.model = model
        self.model_name = "RestateModelWrapper"
        self.max_retries = max_retries

    async def get_response(self, *args, **kwargs) -> ModelResponse:
        async def call_llm() -> RestateModelResponse:
            resp = await self.model.get_response(*args, **kwargs)
            # convert to a Pydantic model so the Restate SDK can serialize it
            return RestateModelResponse(
                output=resp.output,
                usage=resp.usage,
                response_id=resp.response_id,
            )

        ctx = current_context()
        if ctx is None:
            raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
        result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries))
        # convert back to the original ModelResponse
        return ModelResponse(
            output=result.output,
            usage=result.usage,
            response_id=result.response_id,
        )

    def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent]:
        raise TerminalError("Streaming is not supported in Restate. Use `get_response` instead.")


class RestateSession(SessionABC):
    """Restate session implementation following the Session protocol."""

    def _ctx(self) -> ObjectContext:
        return typing.cast(ObjectContext, current_context())

    async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
        """Retrieve conversation history for this session."""
        current_items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
        if limit is not None:
            return current_items[-limit:]
        return current_items

    async def add_items(self, items: List[TResponseInputItem]) -> None:
        """Store new items for this session."""
        current_items = await self.get_items() or []
        self._ctx().set("items", current_items + items)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from this session."""
        current_items = await self.get_items() or []
        if current_items:
            item = current_items.pop()
            self._ctx().set("items", current_items)
            return item
        return None

    async def clear_session(self) -> None:
        """Clear all items for this session."""
        self._ctx().clear("items")


class AgentsTerminalException(AgentsException, TerminalError):
    """Exception that is both an AgentsException and a restate.TerminalError."""

    def __init__(self, *args: object) -> None:
        super().__init__(*args)


class AgentsSuspension(AgentsException, SdkInternalBaseException):
    """Exception that is both an AgentsException and a restate SdkInternalBaseException."""

    def __init__(self, *args: object) -> None:
        super().__init__(*args)


def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
    """A tool error handler that re-raises terminal errors and turns model errors into a user-friendly message."""
    # Re-raise terminal errors and cancellations
    if isinstance(error, TerminalError):
        # For the agent SDK it needs to be an AgentsException, for Restate it needs to be a TerminalError,
        # so we raise a new exception that inherits from both.
        raise AgentsTerminalException(error.message)

    if isinstance(error, ModelBehaviorError):
        return f"An error occurred while calling the tool: {str(error)}"

    raise error


def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str:
    """A tool error handler that turns terminal errors into a user-friendly message so the run can continue."""
    # Turn terminal errors into a tool message instead of failing the agent run
    if isinstance(error, TerminalError):
        return f"An error occurred while running the tool: {str(error)}"

    if isinstance(error, ModelBehaviorError):
        return f"An error occurred while calling the tool: {str(error)}"

    raise error


class Runner:
    """
    A wrapper around Runner.run that automatically configures RunConfig for Restate contexts.

    This class automatically sets up the appropriate model provider (DurableModelCalls) and
    model settings, taking over any model and model_settings configuration provided in the
    original RunConfig.
    """

    @staticmethod
    async def run(
        starting_agent: Agent[TContext],
        disable_tool_autowrapping: bool = False,
        *args: typing.Any,
        run_config: RunConfig | None = None,
        **kwargs,
    ) -> RunResult:
        """
        Run an agent with automatic Restate configuration.

        Returns:
            The result from Runner.run
        """

        current_run_config = run_config or RunConfig()
        new_run_config = dataclasses.replace(
            current_run_config,
            model_provider=DurableModelCalls(),
        )
        restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping)
        return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs)


def sequentialize_and_wrap_tools(
    agent: Agent[TContext],
    disable_tool_autowrapping: bool,
) -> Agent[TContext]:
    """
    Wrap the tools of an agent to use the Restate error handling.

    Returns:
        A new agent with wrapped tools.
    """

    # Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
    # This lock only affects tools for this agent; handoff agents are wrapped recursively.
    sequential_tools_lock = asyncio.Lock()
    wrapped_tools: list[Tool] = []
    for tool in agent.tools:
        if isinstance(tool, FunctionTool):

            def create_wrapper(captured_tool):
                async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
                    await sequential_tools_lock.acquire()

                    async def invoke():
                        result = await captured_tool.on_invoke_tool(tool_context, tool_input)
                        # Ensure Pydantic objects are serialized to dict for LLM compatibility
                        if hasattr(result, "model_dump"):
                            return result.model_dump()
                        elif hasattr(result, "dict"):
                            return result.dict()
                        return result

                    try:
                        if disable_tool_autowrapping:
                            return await invoke()

                        ctx = current_context()
                        if ctx is None:
                            raise RuntimeError(
                                "No current Restate context found, make sure to run inside a Restate handler"
                            )
                        return await ctx.run_typed(captured_tool.name, invoke)
                    finally:
                        sequential_tools_lock.release()

                return on_invoke_tool_wrapper

            wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=create_wrapper(tool)))
        else:
            wrapped_tools.append(tool)

    handoffs_with_wrapped_tools = []
    for handoff in agent.handoffs:
        # recursively wrap tools in handoff agents
        handoffs_with_wrapped_tools.append(sequentialize_and_wrap_tools(handoff, disable_tool_autowrapping))  # type: ignore

    return agent.clone(
        tools=wrapped_tools,
        handoffs=handoffs_with_wrapped_tools,
    )
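
To show where the exported error handlers plug in, here is a short sketch of a tool that opts into `continue_on_terminal_errors`. It assumes the `failure_error_function` hook of the openai-agents `function_tool` decorator and a made-up `lookup_order` tool; the import path is again an assumption.

# Hypothetical tool definition showing the error-handler hooks added in this commit.
# failure_error_function is the openai-agents hook for turning tool exceptions into
# messages for the model; lookup_order and its logic are made up for illustration.
from agents import function_tool
from restate import TerminalError
from restate.ext.openai import continue_on_terminal_errors  # assumed module path

@function_tool(failure_error_function=continue_on_terminal_errors)
async def lookup_order(order_id: str) -> str:
    if not order_id.startswith("ord_"):
        # A TerminalError is not retried by Restate; with continue_on_terminal_errors
        # it is surfaced to the model as a tool message instead of failing the run.
        raise TerminalError(f"Unknown order id format: {order_id}")
    return f"Order {order_id} is on its way."

Swapping in `raise_terminal_errors` would instead re-raise the `TerminalError` as an `AgentsTerminalException`, aborting the agent run rather than feeding the failure back to the model as a tool message.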
