Commit 5fd9357

support multi modal inputs
1 parent ea15bdc commit 5fd9357

File tree: 3 files changed, +136 -24 lines changed
pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 16 additions & 16 deletions
@@ -217,8 +217,8 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
 		Body: &types.LLMRequestBody{
 			ChatCompletions: &types.ChatCompletionsRequest{
 				Messages: []types.Message{
-					{Role: "user", Content: "hello world"},
-					{Role: "assistant", Content: "hi there"},
+					{Role: "user", Content: types.Content{Raw: "hello world"}},
+					{Role: "assistant", Content: types.Content{Raw: "hi there"}},
 				},
 			},
 		},
@@ -252,8 +252,8 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 		Body: &types.LLMRequestBody{
 			ChatCompletions: &types.ChatCompletionsRequest{
 				Messages: []types.Message{
-					{Role: "system", Content: "You are a helpful assistant"},
-					{Role: "user", Content: "Hello, how are you?"},
+					{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
+					{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
 				},
 			},
 		},
@@ -285,10 +285,10 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 		Body: &types.LLMRequestBody{
 			ChatCompletions: &types.ChatCompletionsRequest{
 				Messages: []types.Message{
-					{Role: "system", Content: "You are a helpful assistant"},
-					{Role: "user", Content: "Hello, how are you?"},
-					{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
-					{Role: "user", Content: "Can you explain how prefix caching works?"},
+					{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
+					{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
+					{Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
+					{Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}},
 				},
 			},
 		},
@@ -318,12 +318,12 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 		Body: &types.LLMRequestBody{
 			ChatCompletions: &types.ChatCompletionsRequest{
 				Messages: []types.Message{
-					{Role: "system", Content: "You are a helpful assistant"},
-					{Role: "user", Content: "Hello, how are you?"},
-					{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
-					{Role: "user", Content: "Can you explain how prefix caching works?"},
-					{Role: "assistant", Content: "Prefix caching is a technique where..."},
-					{Role: "user", Content: "That's very helpful, thank you!"},
+					{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
+					{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
+					{Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
+					{Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}},
+					{Role: "assistant", Content: types.Content{Raw: "Prefix caching is a technique where..."}},
+					{Role: "user", Content: types.Content{Raw: "That's very helpful, thank you!"}},
 				},
 			},
 		},
@@ -443,15 +443,15 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
 	b.Run(fmt.Sprintf("messages_%d_length_%d", scenario.messageCount, scenario.messageLength), func(b *testing.B) {
 		// Generate messages for this scenario
 		messages := make([]types.Message, scenario.messageCount)
-		messages[0] = types.Message{Role: "system", Content: "You are a helpful assistant."}
+		messages[0] = types.Message{Role: "system", Content: types.Content{Raw: "You are a helpful assistant."}}
 
 		for i := 1; i < scenario.messageCount; i++ {
 			role := "user"
 			if i%2 == 0 {
 				role = "assistant"
 			}
 			content := randomPrompt(scenario.messageLength)
-			messages[i] = types.Message{Role: role, Content: content}
+			messages[i] = types.Message{Role: role, Content: types.Content{Raw: content}}
 		}
 
 		pod := &types.PodMetrics{
pkg/epp/scheduling/types/types.go

Lines changed: 67 additions & 5 deletions
@@ -17,7 +17,10 @@ limitations under the License.
 package types
 
 import (
+	"encoding/json"
+	"errors"
 	"fmt"
+	"strings"
 
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
@@ -48,7 +51,7 @@ func (r *LLMRequest) String() string {
 
 // LLMRequestBody contains the request-body fields that we parse out as user input,
 // to be used in forming scheduling decisions.
-// An LLMRequestBody must contain exactly one of CompletionsRequest or ChatCompletionsRequest.
+// An LLMRequestBody must contain exactly one of CompletionsRequest, ChatCompletionsRequest, or MultiModalChatCompletions.
 type LLMRequestBody struct {
 	// CompletionsRequest is the representation of the OpenAI /v1/completions request body.
 	Completions *CompletionsRequest `json:"completions,omitempty"`
@@ -97,16 +100,75 @@ func (r *ChatCompletionsRequest) String() string {
 
 	messagesLen := 0
 	for _, msg := range r.Messages {
-		messagesLen += len(msg.Content)
+		messagesLen += len(msg.Content.PlainText())
 	}
-
 	return fmt.Sprintf("{MessagesLength: %d}", messagesLen)
 }
 
 // Message represents a single message in a chat-completions request.
 type Message struct {
-	Role string
-	Content string // TODO: support multi-modal content
+	// Role is the message role; typical values are 'user', 'assistant', ...
+	Role string `json:"role,omitempty"`
+	// Content holds the text or structured (multi-modal) content of this message.
+	Content Content `json:"content,omitempty"`
+}
+
+type Content struct {
+	Raw        string
+	Structured []ContentBlock
+}
+
+type ContentBlock struct {
+	Type     string     `json:"type"`
+	Text     string     `json:"text,omitempty"`
+	ImageURL ImageBlock `json:"image_url,omitempty"`
+}
+
+type ImageBlock struct {
+	Url string `json:"url,omitempty"`
+}
+
+// UnmarshalJSON accepts both content formats: a raw string and a list of structured blocks.
+func (mc *Content) UnmarshalJSON(data []byte) error {
+	// Raw format
+	var str string
+	if err := json.Unmarshal(data, &str); err == nil {
+		mc.Raw = str
+		return nil
+	}
+
+	// Block format
+	var blocks []ContentBlock
+	if err := json.Unmarshal(data, &blocks); err == nil {
+		mc.Structured = blocks
+		return nil
+	}
+
+	return errors.New("content format not supported")
+}
+
+func (mc Content) MarshalJSON() ([]byte, error) {
+	if mc.Raw != "" {
+		return json.Marshal(mc.Raw)
+	}
+	if mc.Structured != nil {
+		return json.Marshal(mc.Structured)
+	}
+	return json.Marshal("")
+}
+
+func (mc Content) PlainText() string {
+	if mc.Raw != "" {
+		return mc.Raw
+	}
+	var sb strings.Builder
+	for _, block := range mc.Structured {
+		if block.Type == "text" {
+			sb.WriteString(block.Text)
+			sb.WriteString(" ")
+		}
+	}
+	return sb.String()
 }
 
 type Pod interface {
pkg/epp/util/request/body_test.go

Lines changed: 53 additions & 3 deletions
@@ -58,8 +58,58 @@ func TestExtractRequestData(t *testing.T) {
 			want: &types.LLMRequestBody{
 				ChatCompletions: &types.ChatCompletionsRequest{
 					Messages: []types.Message{
-						{Role: "system", Content: "this is a system message"},
-						{Role: "user", Content: "hello"},
+						{Role: "system", Content: types.Content{Raw: "this is a system message"}},
+						{Role: "user", Content: types.Content{Raw: "hello"}},
+					},
+				},
+			},
+		},
+		{
+			name: "chat completions request body with multi-modal content",
+			body: map[string]any{
+				"model": "test",
+				"messages": []any{
+					map[string]any{
+						"role": "system",
+						"content": []map[string]any{
+							{
+								"type": "text",
+								"text": "Describe this image in one sentence.",
+							},
+						},
+					},
+					map[string]any{
+						"role": "user",
+						"content": []map[string]any{
+							{
+								"type": "image_url",
+								"image_url": map[string]any{
+									"url": "https://example.com/images/dui.jpg.",
+								},
+							},
+						},
+					},
+				},
+			},
+			want: &types.LLMRequestBody{
+				ChatCompletions: &types.ChatCompletionsRequest{
+					Messages: []types.Message{
+						{Role: "system", Content: types.Content{
+							Structured: []types.ContentBlock{
+								{
+									Text: "Describe this image in one sentence.",
+									Type: "text",
+								},
+							},
+						}},
+						{Role: "user", Content: types.Content{
+							Structured: []types.ContentBlock{
+								{
+									Type: "image_url",
+									ImageURL: types.ImageBlock{Url: "https://example.com/images/dui.jpg."},
+								},
+							},
+						}},
 					},
 				},
 			},
@@ -81,7 +131,7 @@ func TestExtractRequestData(t *testing.T) {
 			},
 			want: &types.LLMRequestBody{
 				ChatCompletions: &types.ChatCompletionsRequest{
-					Messages: []types.Message{{Role: "user", Content: "hello"}},
+					Messages: []types.Message{{Role: "user", Content: types.Content{Raw: "hello"}}},
 					Tools: []any{map[string]any{"type": "function"}},
 					Documents: []any{map[string]any{"content": "doc"}},
 					ChatTemplate: "custom template",