Skip to content

Commit aad8acc

Browse files
authored
Expose the "store" parameter through ModelSettings (#357)
Closes #173. This will also set stored completions to True by default, encouraging a best practice.
1 parent 4854a74 commit aad8acc

File tree

5 files changed

+52
-13
lines changed

5 files changed

+52
-13
lines changed

src/agents/model_settings.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, fields, replace
44
from typing import Literal
55

66

@@ -30,27 +30,29 @@ class ModelSettings:
3030
tool_choice: Literal["auto", "required", "none"] | str | None = None
3131
"""The tool choice to use when calling the model."""
3232

33-
parallel_tool_calls: bool | None = False
34-
"""Whether to use parallel tool calls when calling the model."""
33+
parallel_tool_calls: bool | None = None
34+
"""Whether to use parallel tool calls when calling the model.
35+
Defaults to False if not provided."""
3536

3637
truncation: Literal["auto", "disabled"] | None = None
3738
"""The truncation strategy to use when calling the model."""
3839

3940
max_tokens: int | None = None
4041
"""The maximum number of output tokens to generate."""
4142

43+
store: bool | None = None
44+
"""Whether to store the generated model response for later retrieval.
45+
Defaults to True if not provided."""
46+
4247
def resolve(self, override: ModelSettings | None) -> ModelSettings:
4348
"""Produce a new ModelSettings by overlaying any non-None values from the
4449
override on top of this instance."""
4550
if override is None:
4651
return self
47-
return ModelSettings(
48-
temperature=override.temperature or self.temperature,
49-
top_p=override.top_p or self.top_p,
50-
frequency_penalty=override.frequency_penalty or self.frequency_penalty,
51-
presence_penalty=override.presence_penalty or self.presence_penalty,
52-
tool_choice=override.tool_choice or self.tool_choice,
53-
parallel_tool_calls=override.parallel_tool_calls or self.parallel_tool_calls,
54-
truncation=override.truncation or self.truncation,
55-
max_tokens=override.max_tokens or self.max_tokens,
56-
)
52+
53+
changes = {
54+
field.name: getattr(override, field.name)
55+
for field in fields(self)
56+
if getattr(override, field.name) is not None
57+
}
58+
return replace(self, **changes)

src/agents/models/openai_chatcompletions.py

+4
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,9 @@ async def _fetch_response(
518518
f"Response format: {response_format}\n"
519519
)
520520

521+
# Match the behavior of Responses where store is True when not given
522+
store = model_settings.store if model_settings.store is not None else True
523+
521524
ret = await self._get_client().chat.completions.create(
522525
model=self.model,
523526
messages=converted_messages,
@@ -532,6 +535,7 @@ async def _fetch_response(
532535
parallel_tool_calls=parallel_tool_calls,
533536
stream=stream,
534537
stream_options={"include_usage": True} if stream else NOT_GIVEN,
538+
store=store,
535539
extra_headers=_HEADERS,
536540
)
537541

src/agents/models/openai_responses.py

+1
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ async def _fetch_response(
246246
stream=stream,
247247
extra_headers=_HEADERS,
248248
text=response_format,
249+
store=self._non_null_or_not_given(model_settings.store),
249250
)
250251

251252
def _get_client(self) -> AsyncOpenAI:

tests/test_agent_runner.py

+30
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
InputGuardrail,
1515
InputGuardrailTripwireTriggered,
1616
ModelBehaviorError,
17+
ModelSettings,
1718
OutputGuardrail,
1819
OutputGuardrailTripwireTriggered,
20+
RunConfig,
1921
RunContextWrapper,
2022
Runner,
2123
UserError,
@@ -634,3 +636,31 @@ async def test_tool_use_behavior_custom_function():
634636

635637
assert len(result.raw_responses) == 2, "should have two model responses"
636638
assert result.final_output == "the_final_output", "should have used the custom function"
639+
640+
641+
@pytest.mark.asyncio
642+
async def test_model_settings_override():
643+
model = FakeModel()
644+
agent = Agent(
645+
name="test",
646+
model=model,
647+
model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
648+
)
649+
650+
model.add_multiple_turn_outputs(
651+
[
652+
[
653+
get_text_message("a_message"),
654+
],
655+
]
656+
)
657+
658+
await Runner.run(
659+
agent,
660+
input="user_message",
661+
run_config=RunConfig(model_settings=ModelSettings(0.5)),
662+
)
663+
664+
# temperature is overridden by Runner.run, but max_tokens is not
665+
assert model.last_turn_args["model_settings"].temperature == 0.5
666+
assert model.last_turn_args["model_settings"].max_tokens == 1000

tests/test_openai_chatcompletions.py

+2
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(self, completions: DummyCompletions) -> None:
226226
# Ensure expected args were passed through to OpenAI client.
227227
kwargs = completions.kwargs
228228
assert kwargs["stream"] is False
229+
assert kwargs["store"] is True
229230
assert kwargs["model"] == "gpt-4"
230231
assert kwargs["messages"][0]["role"] == "system"
231232
assert kwargs["messages"][0]["content"] == "sys"
@@ -279,6 +280,7 @@ def __init__(self, completions: DummyCompletions) -> None:
279280
)
280281
# Check OpenAI client was called for streaming
281282
assert completions.kwargs["stream"] is True
283+
assert completions.kwargs["store"] is True
282284
assert completions.kwargs["stream_options"] == {"include_usage": True}
283285
# Response is a proper openai Response
284286
assert isinstance(response, Response)

0 commit comments

Comments (0)