|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from dataclasses import dataclass |
| 3 | +from dataclasses import dataclass, fields, replace |
4 | 4 | from typing import Literal
|
5 | 5 |
|
6 | 6 |
|
@@ -30,27 +30,29 @@ class ModelSettings:
|
30 | 30 | tool_choice: Literal["auto", "required", "none"] | str | None = None
|
31 | 31 | """The tool choice to use when calling the model."""
|
32 | 32 |
|
33 |
| - parallel_tool_calls: bool | None = False |
34 |
| - """Whether to use parallel tool calls when calling the model.""" |
| 33 | + parallel_tool_calls: bool | None = None |
| 34 | + """Whether to use parallel tool calls when calling the model. |
| 35 | + Defaults to False if not provided.""" |
35 | 36 |
|
36 | 37 | truncation: Literal["auto", "disabled"] | None = None
|
37 | 38 | """The truncation strategy to use when calling the model."""
|
38 | 39 |
|
39 | 40 | max_tokens: int | None = None
|
40 | 41 | """The maximum number of output tokens to generate."""
|
41 | 42 |
|
| 43 | + store: bool | None = None |
| 44 | + """Whether to store the generated model response for later retrieval. |
| 45 | + Defaults to True if not provided.""" |
| 46 | + |
42 | 47 | def resolve(self, override: ModelSettings | None) -> ModelSettings:
|
43 | 48 | """Produce a new ModelSettings by overlaying any non-None values from the
|
44 | 49 | override on top of this instance."""
|
45 | 50 | if override is None:
|
46 | 51 | return self
|
47 |
| - return ModelSettings( |
48 |
| - temperature=override.temperature or self.temperature, |
49 |
| - top_p=override.top_p or self.top_p, |
50 |
| - frequency_penalty=override.frequency_penalty or self.frequency_penalty, |
51 |
| - presence_penalty=override.presence_penalty or self.presence_penalty, |
52 |
| - tool_choice=override.tool_choice or self.tool_choice, |
53 |
| - parallel_tool_calls=override.parallel_tool_calls or self.parallel_tool_calls, |
54 |
| - truncation=override.truncation or self.truncation, |
55 |
| - max_tokens=override.max_tokens or self.max_tokens, |
56 |
| - ) |
| 52 | + |
| 53 | + changes = { |
| 54 | + field.name: getattr(override, field.name) |
| 55 | + for field in fields(self) |
| 56 | + if getattr(override, field.name) is not None |
| 57 | + } |
| 58 | + return replace(self, **changes) |
0 commit comments