Skip to content

Commit a924846

Browse files
committed
feat: chat completion message add extra fields
1 parent 07791be commit a924846

File tree

5 files changed

+159
-30
lines changed

5 files changed

+159
-30
lines changed

chat.go

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"io"
99
"net/http"
10+
"reflect"
1011

1112
openai "github.com/meguminnnnnnnnn/go-openai/internal"
1213

@@ -136,6 +137,8 @@ type ChatCompletionMessage struct {
136137

137138
// For Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool.
138139
ToolCallID string `json:"tool_call_id,omitempty"`
140+
141+
ExtraFields map[string]json.RawMessage `json:"-"`
139142
}
140143

141144
func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
@@ -144,29 +147,31 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
144147
}
145148
if len(m.MultiContent) > 0 {
146149
msg := struct {
147-
Role string `json:"role"`
148-
Content string `json:"-"`
149-
Refusal string `json:"refusal,omitempty"`
150-
MultiContent []ChatMessagePart `json:"content,omitempty"`
151-
Name string `json:"name,omitempty"`
152-
ReasoningContent string `json:"reasoning_content,omitempty"`
153-
FunctionCall *FunctionCall `json:"function_call,omitempty"`
154-
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
155-
ToolCallID string `json:"tool_call_id,omitempty"`
150+
Role string `json:"role"`
151+
Content string `json:"-"`
152+
Refusal string `json:"refusal,omitempty"`
153+
MultiContent []ChatMessagePart `json:"content,omitempty"`
154+
Name string `json:"name,omitempty"`
155+
ReasoningContent string `json:"reasoning_content,omitempty"`
156+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
157+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
158+
ToolCallID string `json:"tool_call_id,omitempty"`
159+
ExtraFields map[string]json.RawMessage `json:"-"`
156160
}(m)
157161
return json.Marshal(msg)
158162
}
159163

160164
msg := struct {
161-
Role string `json:"role"`
162-
Content string `json:"content,omitempty"`
163-
Refusal string `json:"refusal,omitempty"`
164-
MultiContent []ChatMessagePart `json:"-"`
165-
Name string `json:"name,omitempty"`
166-
ReasoningContent string `json:"reasoning_content,omitempty"`
167-
FunctionCall *FunctionCall `json:"function_call,omitempty"`
168-
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
169-
ToolCallID string `json:"tool_call_id,omitempty"`
165+
Role string `json:"role"`
166+
Content string `json:"content,omitempty"`
167+
Refusal string `json:"refusal,omitempty"`
168+
MultiContent []ChatMessagePart `json:"-"`
169+
Name string `json:"name,omitempty"`
170+
ReasoningContent string `json:"reasoning_content,omitempty"`
171+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
172+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
173+
ToolCallID string `json:"tool_call_id,omitempty"`
174+
ExtraFields map[string]json.RawMessage `json:"-"`
170175
}(m)
171176
return json.Marshal(msg)
172177
}
@@ -177,32 +182,49 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
177182
Content string `json:"content"`
178183
Refusal string `json:"refusal,omitempty"`
179184
MultiContent []ChatMessagePart
180-
Name string `json:"name,omitempty"`
181-
ReasoningContent string `json:"reasoning_content,omitempty"`
182-
FunctionCall *FunctionCall `json:"function_call,omitempty"`
183-
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
184-
ToolCallID string `json:"tool_call_id,omitempty"`
185+
Name string `json:"name,omitempty"`
186+
ReasoningContent string `json:"reasoning_content,omitempty"`
187+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
188+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
189+
ToolCallID string `json:"tool_call_id,omitempty"`
190+
ExtraFields map[string]json.RawMessage `json:"-"`
185191
}{}
186192

187193
if err := json.Unmarshal(bs, &msg); err == nil {
188194
*m = ChatCompletionMessage(msg)
195+
var extra map[string]json.RawMessage
196+
extra, err = openai.UnmarshalExtraFields(reflect.TypeOf(m), bs)
197+
if err != nil {
198+
return err
199+
}
200+
201+
m.ExtraFields = extra
189202
return nil
190203
}
204+
191205
multiMsg := struct {
192206
Role string `json:"role"`
193207
Content string
194-
Refusal string `json:"refusal,omitempty"`
195-
MultiContent []ChatMessagePart `json:"content"`
196-
Name string `json:"name,omitempty"`
197-
ReasoningContent string `json:"reasoning_content,omitempty"`
198-
FunctionCall *FunctionCall `json:"function_call,omitempty"`
199-
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
200-
ToolCallID string `json:"tool_call_id,omitempty"`
208+
Refusal string `json:"refusal,omitempty"`
209+
MultiContent []ChatMessagePart `json:"content"`
210+
Name string `json:"name,omitempty"`
211+
ReasoningContent string `json:"reasoning_content,omitempty"`
212+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
213+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
214+
ToolCallID string `json:"tool_call_id,omitempty"`
215+
ExtraFields map[string]json.RawMessage `json:"-"`
201216
}{}
202217
if err := json.Unmarshal(bs, &multiMsg); err != nil {
203218
return err
204219
}
205220
*m = ChatCompletionMessage(multiMsg)
221+
222+
extra, err := openai.UnmarshalExtraFields(reflect.TypeOf(m), bs)
223+
if err != nil {
224+
return err
225+
}
226+
227+
m.ExtraFields = extra
206228
return nil
207229
}
208230

chat_stream.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@ package openai
22

33
import (
44
"context"
5+
"encoding/json"
56
"io"
67
"net/http"
8+
"reflect"
9+
10+
openai "github.com/meguminnnnnnnnn/go-openai/internal"
711
)
812

913
type ChatCompletionStreamChoiceDelta struct {
@@ -18,6 +22,35 @@ type ChatCompletionStreamChoiceDelta struct {
1822
// the doc from deepseek:
1923
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
2024
ReasoningContent string `json:"reasoning_content,omitempty"`
25+
26+
ExtraFields map[string]json.RawMessage `json:"-"`
27+
}
28+
29+
func (c *ChatCompletionStreamChoiceDelta) UnmarshalJSON(bs []byte) error {
30+
msg := struct {
31+
Content string `json:"content,omitempty"`
32+
Role string `json:"role,omitempty"`
33+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
34+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
35+
Refusal string `json:"refusal,omitempty"`
36+
ReasoningContent string `json:"reasoning_content,omitempty"`
37+
38+
ExtraFields map[string]json.RawMessage `json:"-"`
39+
}{}
40+
err := json.Unmarshal(bs, &msg)
41+
if err != nil {
42+
return err
43+
}
44+
45+
*c = msg
46+
var extra map[string]json.RawMessage
47+
extra, err = openai.UnmarshalExtraFields(reflect.TypeOf(c), bs)
48+
if err != nil {
49+
return err
50+
}
51+
52+
c.ExtraFields = extra
53+
return nil
2154
}
2255

2356
type ChatCompletionStreamChoiceLogprobs struct {

chat_stream_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/meguminnnnnnnnn/go-openai"
1414
"github.com/meguminnnnnnnnn/go-openai/internal/test/checks"
15+
"github.com/stretchr/testify/assert"
1516
)
1617

1718
func TestChatCompletionsStreamWrongModel(t *testing.T) {
@@ -1021,3 +1022,34 @@ func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice)
10211022
}
10221023
return true
10231024
}
1025+
1026+
func TestChatCompletionStreamChoiceDelta_UnmarshalJSON(t *testing.T) {
1027+
bs := []byte(`{
1028+
"content": "Hello!",
1029+
"role": "user",
1030+
"multimodal_content": {
1031+
"type": "inline_data",
1032+
"inline_data": {
1033+
"mime_type": "image/png",
1034+
"data": "iVB"
1035+
}
1036+
}
1037+
}
1038+
`)
1039+
1040+
delta := openai.ChatCompletionStreamChoiceDelta{}
1041+
err := json.Unmarshal(bs, &delta)
1042+
assert.NoError(t, err)
1043+
multimodalContent, ok := delta.ExtraFields["multimodal_content"]
1044+
assert.True(t, ok)
1045+
content := map[string]any{}
1046+
err = json.Unmarshal(multimodalContent, &content)
1047+
assert.NoError(t, err)
1048+
assert.Equal(t, map[string]any{
1049+
"type": "inline_data",
1050+
"inline_data": map[string]interface{}{
1051+
"mime_type": "image/png",
1052+
"data": "iVB",
1053+
},
1054+
}, content)
1055+
}

chat_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/meguminnnnnnnnn/go-openai"
1616
"github.com/meguminnnnnnnnn/go-openai/internal/test/checks"
1717
"github.com/meguminnnnnnnnn/go-openai/jsonschema"
18+
"github.com/stretchr/testify/assert"
1819
)
1920

2021
const (
@@ -1160,3 +1161,42 @@ func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) {
11601161
})
11611162
}
11621163
}
1164+
1165+
func TestChatCompletionMessage_UnmarshalJSON(t *testing.T) {
1166+
bs := []byte(`{
1167+
"role": "system",
1168+
"content": "You are a helpful math tutor.",
1169+
"name": "name",
1170+
"multimodal_contents": [
1171+
{
1172+
"type": "text",
1173+
"text": "ok"
1174+
},
1175+
{
1176+
"type": "text",
1177+
"text": "Generate a picture of a Shiba Inu dog for you。"
1178+
},
1179+
{
1180+
"type": "inline_data",
1181+
"inline_data": {
1182+
"mime_type": "image/png",
1183+
"data": "iVBI"
1184+
}
1185+
}
1186+
]
1187+
}`)
1188+
chatMessage := &openai.ChatCompletionMessage{}
1189+
err := json.Unmarshal(bs, chatMessage)
1190+
assert.Nil(t, err)
1191+
1192+
multimodalContent := chatMessage.ExtraFields["multimodal_contents"]
1193+
mContents := make([]map[string]any, 0)
1194+
err = json.Unmarshal(multimodalContent, &mContents)
1195+
assert.Nil(t, err)
1196+
1197+
assert.Equal(t, mContents, []map[string]any{
1198+
{"type": "text", "text": "ok"},
1199+
{"type": "text", "text": "Generate a picture of a Shiba Inu dog for you。"},
1200+
{"type": "inline_data", "inline_data": map[string]any{"mime_type": "image/png", "data": "iVBI"}},
1201+
})
1202+
}

internal/unmarshaler.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"fmt"
66
"reflect"
7+
"strings"
78

89
"github.com/bytedance/sonic"
910
)
@@ -37,6 +38,7 @@ func UnmarshalExtraFields(typ reflect.Type, data []byte) (map[string]json.RawMes
3738

3839
jsonTag := field.Tag.Get("json")
3940
if jsonTag != "" {
41+
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
4042
delete(m, jsonTag)
4143
} else {
4244
if !field.IsExported() {

0 commit comments

Comments
 (0)