Skip to content

Commit 49617c7

Browse files
authored
feat: Unify the SSE processing logic (alibaba#1800)
1 parent 53a015d commit 49617c7

File tree

5 files changed

+161
-199
lines changed

5 files changed

+161
-199
lines changed

plugins/wasm-go/extensions/ai-proxy/main.go

+45-9
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,21 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
102102

103103
// Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses,
104104
// allowing plugins to inspect or modify the response correctly
105-
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
105+
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
106106

107107
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
108108
// Set the apiToken for the current request.
109109
providerConfig.SetApiTokenInUse(ctx, log)
110110

111111
err := handler.OnRequestHeaders(ctx, apiName, log)
112112
if err != nil {
113-
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
113+
_ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
114114
return types.ActionContinue
115115
}
116116

117117
hasRequestBody := wrapper.HasRequestBody()
118118
if hasRequestBody {
119-
proxywasm.RemoveHttpRequestHeader("Content-Length")
119+
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
120120
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
121121
// Delay the header processing to allow changing in OnRequestBody
122122
return types.HeaderStopIteration
@@ -143,7 +143,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
143143

144144
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
145145
if settingErr != nil {
146-
util.ErrorHandler(
146+
_ = util.ErrorHandler(
147147
"ai-proxy.proc_req_body_failed",
148148
fmt.Errorf("failed to replace request body by custom settings: %v", settingErr),
149149
)
@@ -156,7 +156,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
156156
if err == nil {
157157
return action
158158
}
159-
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
159+
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
160160
}
161161
return types.ActionContinue
162162
}
@@ -205,7 +205,11 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
205205

206206
checkStream(ctx, log)
207207
_, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler)
208-
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
208+
var needHandleStreamingBody bool
209+
_, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler)
210+
if !needHandleStreamingBody {
211+
_, needHandleStreamingBody = activeProvider.(provider.StreamingEventHandler)
212+
}
209213
if !needHandleBody && !needHandleStreamingBody {
210214
ctx.DontReadResponseBody()
211215
} else if !needHandleStreamingBody {
@@ -224,7 +228,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
224228
}
225229

226230
log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType())
227-
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
231+
log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
228232

229233
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
230234
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
@@ -234,6 +238,38 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
234238
}
235239
return chunk
236240
}
241+
if handler, ok := activeProvider.(provider.StreamingEventHandler); ok {
242+
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
243+
events := provider.ExtractStreamingEvents(ctx, chunk, log)
244+
log.Debugf("[onStreamingResponseBody] %d events received", len(events))
245+
if len(events) == 0 {
246+
// No events are extracted, return the original chunk
247+
return chunk
248+
}
249+
var responseBuilder strings.Builder
250+
for _, event := range events {
251+
log.Debugf("processing event: %v", event)
252+
253+
if event.IsEndData() {
254+
responseBuilder.WriteString(event.ToHttpString())
255+
continue
256+
}
257+
258+
outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event, log)
259+
if err != nil {
260+
log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
261+
return chunk
262+
}
263+
if outputEvents == nil || len(outputEvents) == 0 {
264+
responseBuilder.WriteString(event.ToHttpString())
265+
} else {
266+
for _, outputEvent := range outputEvents {
267+
responseBuilder.WriteString(outputEvent.ToHttpString())
268+
}
269+
}
270+
}
271+
return []byte(responseBuilder.String())
272+
}
237273
return chunk
238274
}
239275

@@ -251,11 +287,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
251287
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
252288
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
253289
if err != nil {
254-
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
290+
_ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
255291
return types.ActionContinue
256292
}
257293
if err = provider.ReplaceResponseBody(body, log); err != nil {
258-
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
294+
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
259295
}
260296
}
261297
return types.ActionContinue

plugins/wasm-go/extensions/ai-proxy/provider/model.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -278,14 +278,18 @@ func (m *functionCall) IsEmpty() bool {
278278
return m.Name == "" && m.Arguments == ""
279279
}
280280

281-
type streamEvent struct {
281+
type StreamEvent struct {
282282
Id string `json:"id"`
283283
Event string `json:"event"`
284284
Data string `json:"data"`
285285
HttpStatus string `json:"http_status"`
286286
}
287287

288-
func (e *streamEvent) setValue(key, value string) {
288+
func (e *StreamEvent) IsEndData() bool {
289+
return e.Data == streamEndDataValue
290+
}
291+
292+
func (e *StreamEvent) SetValue(key, value string) {
289293
switch key {
290294
case streamEventIdItemKey:
291295
e.Id = value
@@ -300,6 +304,10 @@ func (e *streamEvent) setValue(key, value string) {
300304
}
301305
}
302306

307+
func (e *StreamEvent) ToHttpString() string {
308+
return fmt.Sprintf("%s %s\n\n", streamDataItemKey, e.Data)
309+
}
310+
303311
// https://platform.openai.com/docs/guides/images
304312
type imageGenerationRequest struct {
305313
Model string `json:"model"`

plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go

+8-79
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,12 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
102102
}()
103103
if err != nil {
104104
log.Errorf("failed to load context file: %v", err)
105-
util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
105+
_ = util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
106106
return
107107
}
108108
err = m.performChatCompletion(ctx, content, request, log)
109109
if err != nil {
110-
util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
110+
_ = util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
111111
}
112112
}, log)
113113
if err == nil {
@@ -161,100 +161,29 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
161161
}
162162
}
163163

164-
func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
164+
func (m *moonshotProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
165165
if name != ApiNameChatCompletion {
166-
return chunk, nil
167-
}
168-
receivedBody := chunk
169-
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
170-
receivedBody = append(bufferedStreamingBody, chunk...)
171-
}
172-
173-
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
174-
175-
defer func() {
176-
if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) {
177-
// Just in case the received chunk is not a complete event.
178-
ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:])
179-
} else {
180-
ctx.SetContext(ctxKeyStreamingBody, nil)
181-
}
182-
}()
183-
184-
var responseBuilder strings.Builder
185-
currentKey := ""
186-
currentEvent := &streamEvent{}
187-
i, length := 0, len(receivedBody)
188-
for i = 0; i < length; i++ {
189-
ch := receivedBody[i]
190-
if ch != '\n' {
191-
if lineStartIndex == -1 {
192-
if eventStartIndex == -1 {
193-
eventStartIndex = i
194-
}
195-
lineStartIndex = i
196-
valueStartIndex = -1
197-
}
198-
if valueStartIndex == -1 {
199-
if ch == ':' {
200-
valueStartIndex = i + 1
201-
currentKey = string(receivedBody[lineStartIndex:valueStartIndex])
202-
}
203-
} else if valueStartIndex == i && ch == ' ' {
204-
// Skip leading spaces in data.
205-
valueStartIndex = i + 1
206-
}
207-
continue
208-
}
209-
210-
if lineStartIndex != -1 {
211-
value := string(receivedBody[valueStartIndex:i])
212-
currentEvent.setValue(currentKey, value)
213-
} else {
214-
// Extra new line. The current event is complete.
215-
log.Debugf("processing event: %v", currentEvent)
216-
m.convertStreamEvent(&responseBuilder, currentEvent, log)
217-
// Reset event parsing state.
218-
eventStartIndex = -1
219-
currentEvent = &streamEvent{}
220-
}
221-
222-
// Reset line parsing state.
223-
lineStartIndex = -1
224-
valueStartIndex = -1
225-
currentKey = ""
226-
}
227-
228-
modifiedResponseChunk := responseBuilder.String()
229-
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
230-
return []byte(modifiedResponseChunk), nil
231-
}
232-
233-
func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error {
234-
if event.Data == streamEndDataValue {
235-
m.appendStreamEvent(responseBuilder, event)
236-
return nil
166+
return nil, nil
237167
}
238168

239169
if gjson.Get(event.Data, "choices.0.usage").Exists() {
240170
usageStr := gjson.Get(event.Data, "choices.0.usage").Raw
241171
newData, err := sjson.Delete(event.Data, "choices.0.usage")
242172
if err != nil {
243173
log.Errorf("convert usage event error: %v", err)
244-
return err
174+
return nil, err
245175
}
246176
newData, err = sjson.SetRaw(newData, "usage", usageStr)
247177
if err != nil {
248178
log.Errorf("convert usage event error: %v", err)
249-
return err
179+
return nil, err
250180
}
251181
event.Data = newData
252182
}
253-
m.appendStreamEvent(responseBuilder, event)
254-
return nil
183+
return []StreamEvent{event}, nil
255184
}
256185

257-
func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
186+
func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *StreamEvent) {
258187
responseBuilder.WriteString(streamDataItemKey)
259188
responseBuilder.WriteString(event.Data)
260189
responseBuilder.WriteString("\n\n")

plugins/wasm-go/extensions/ai-proxy/provider/provider.go

+79
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ type StreamingResponseBodyHandler interface {
149149
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
150150
}
151151

152+
type StreamingEventHandler interface {
153+
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error)
154+
}
155+
152156
type ApiNameHandler interface {
153157
GetApiName(path string) ApiName
154158
}
@@ -575,6 +579,81 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
575579
return ""
576580
}
577581

582+
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []StreamEvent {
583+
body := chunk
584+
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
585+
body = append(bufferedStreamingBody, chunk...)
586+
}
587+
588+
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
589+
590+
defer func() {
591+
if eventStartIndex >= 0 && eventStartIndex < len(body) {
592+
// Just in case the received chunk is not a complete event.
593+
ctx.SetContext(ctxKeyStreamingBody, body[eventStartIndex:])
594+
} else {
595+
ctx.SetContext(ctxKeyStreamingBody, nil)
596+
}
597+
}()
598+
599+
// Sample Qwen event response:
600+
//
601+
// event:result
602+
// :HTTP_STATUS/200
603+
// data:{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"}
604+
//
605+
// event:error
606+
// :HTTP_STATUS/400
607+
// data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"}
608+
//
609+
610+
var events []StreamEvent
611+
612+
currentKey := ""
613+
currentEvent := &StreamEvent{}
614+
i, length := 0, len(body)
615+
for i = 0; i < length; i++ {
616+
ch := body[i]
617+
if ch != '\n' {
618+
if lineStartIndex == -1 {
619+
if eventStartIndex == -1 {
620+
eventStartIndex = i
621+
}
622+
lineStartIndex = i
623+
valueStartIndex = -1
624+
}
625+
if valueStartIndex == -1 {
626+
if ch == ':' {
627+
valueStartIndex = i + 1
628+
currentKey = string(body[lineStartIndex:valueStartIndex])
629+
}
630+
} else if valueStartIndex == i && ch == ' ' {
631+
// Skip leading spaces in data.
632+
valueStartIndex = i + 1
633+
}
634+
continue
635+
}
636+
637+
if lineStartIndex != -1 {
638+
value := string(body[valueStartIndex:i])
639+
currentEvent.SetValue(currentKey, value)
640+
} else {
641+
// Extra new line. The current event is complete.
642+
events = append(events, *currentEvent)
643+
// Reset event parsing state.
644+
eventStartIndex = -1
645+
currentEvent = &StreamEvent{}
646+
}
647+
648+
// Reset line parsing state.
649+
lineStartIndex = -1
650+
valueStartIndex = -1
651+
currentKey = ""
652+
}
653+
654+
return events
655+
}
656+
578657
func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
579658
_, exist := c.capabilities[string(apiName)]
580659
return exist

0 commit comments

Comments
 (0)