Skip to content

Commit 291497b

Browse files
committed
Workaround openai issue with temperature: 0 being omitted from request
This PR adds a workaround for [this GitHub discussion][gh], describing how chat completion requests which explicitly set temperature = 0 are marshalled to JSON incorrectly (due to the 0 being indistinguishable from the zero value of the field). To do so we add a custom unmarshal method which checks whether the field was defined and set to zero, and explicitly sets the temperature to `math.SmallestNonzeroFloat32` in such cases. This isn't perfect but is likely to get very similar results in practice. [gh]: sashabaranov/go-openai#9 (comment)
1 parent d1c6507 commit 291497b

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

Diff for: packages/grafana-llm-app/pkg/plugin/llm_provider.go

+27
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"math"
89
"math/rand"
910
"strings"
1011

@@ -67,6 +68,32 @@ type ChatCompletionRequest struct {
6768
Model Model `json:"model"`
6869
}
6970

71+
// UnmarshalJSON implements json.Unmarshaler.
72+
// We have a custom implementation here to check whether temperature is being
73+
// explicitly set to `0` in the incoming request, because the `openai.ChatCompletionRequest`
74+
// struct has `omitempty` on the Temperature field and would omit it when marshaling.
75+
// If there is an explicit 0 value in the request, we set it to `math.SmallestNonzeroFloat32`,
76+
// a workaround mentioned in https://github.com/sashabaranov/go-openai/issues/9#issuecomment-894845206.
77+
func (c *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
78+
// Create a wrapper type alias to avoid recursion, otherwise the
79+
// subsequent call to UnmarshalJSON would call this method forever.
80+
type Alias ChatCompletionRequest
81+
var a Alias
82+
if err := json.Unmarshal(data, &a); err != nil {
83+
return err
84+
}
85+
// Also unmarshal to a map to check if temperature is being set explicitly in the request.
86+
r := map[string]any{}
87+
if err := json.Unmarshal(data, &r); err != nil {
88+
return err
89+
}
90+
if t, ok := r["temperature"].(float64); ok && t == 0 {
91+
a.ChatCompletionRequest.Temperature = math.SmallestNonzeroFloat32
92+
}
93+
*c = ChatCompletionRequest(a)
94+
return nil
95+
}
96+
7097
type ChatCompletionStreamResponse struct {
7198
openai.ChatCompletionStreamResponse
7299
// Random padding used to mitigate side channel attacks.

Diff for: packages/grafana-llm-app/pkg/plugin/llm_provider_test.go

+45-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package plugin
22

33
import (
44
"encoding/json"
5+
"math"
56
"testing"
67

78
"github.com/sashabaranov/go-openai"
9+
"github.com/stretchr/testify/assert"
810
)
911

1012
func TestModelFromString(t *testing.T) {
@@ -83,7 +85,7 @@ func TestModelFromString(t *testing.T) {
8385
}
8486
}
8587

86-
func TestUnmarshalJSON(t *testing.T) {
88+
func TestModelUnmarshalJSON(t *testing.T) {
8789
tests := []struct {
8890
input []byte
8991
expected Model
@@ -164,6 +166,48 @@ func TestUnmarshalJSON(t *testing.T) {
164166
}
165167
}
166168

169+
func TestChatCompletionRequestUnmarshalJSON(t *testing.T) {
170+
for _, tt := range []struct {
171+
input []byte
172+
expected ChatCompletionRequest
173+
}{
174+
{
175+
input: []byte(`{"model":"base"}`),
176+
expected: ChatCompletionRequest{
177+
Model: ModelBase,
178+
ChatCompletionRequest: openai.ChatCompletionRequest{
179+
Temperature: 0,
180+
},
181+
},
182+
},
183+
{
184+
input: []byte(`{"model":"base", "temperature":0.5}`),
185+
expected: ChatCompletionRequest{
186+
Model: ModelBase,
187+
ChatCompletionRequest: openai.ChatCompletionRequest{
188+
Temperature: 0.5,
189+
},
190+
},
191+
},
192+
{
193+
input: []byte(`{"model":"base", "temperature":0}`),
194+
expected: ChatCompletionRequest{
195+
Model: ModelBase,
196+
ChatCompletionRequest: openai.ChatCompletionRequest{
197+
Temperature: math.SmallestNonzeroFloat32,
198+
},
199+
},
200+
},
201+
} {
202+
t.Run(string(tt.input), func(t *testing.T) {
203+
var req ChatCompletionRequest
204+
err := json.Unmarshal(tt.input, &req)
205+
assert.NoError(t, err)
206+
assert.Equal(t, tt.expected, req)
207+
})
208+
}
209+
}
210+
167211
func TestChatCompletionStreamResponseMarshalJSON(t *testing.T) {
168212
resp := ChatCompletionStreamResponse{
169213
ChatCompletionStreamResponse: openai.ChatCompletionStreamResponse{

0 commit comments

Comments
 (0)