Skip to content

Commit

Permalink
qwen bailian compatible bug fix (#1597)
Browse files Browse the repository at this point in the history
  • Loading branch information
rinfx authored Dec 17, 2024
1 parent 2a200cd commit 2f5709a
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 49 deletions.
5 changes: 4 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const (
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
qwenBailianPath = "/api/v1/apps"
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"

qwenTopPMin = 0.000001
Expand Down Expand Up @@ -71,7 +72,8 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
}
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

if m.config.qwenEnableCompatible {
if m.config.IsOriginal() {
} else if m.config.qwenEnableCompatible {
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion {
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
Expand Down Expand Up @@ -762,6 +764,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
switch {
case strings.Contains(path, qwenChatCompletionPath),
strings.Contains(path, qwenMultimodalGenerationPath),
strings.Contains(path, qwenBailianPath),
strings.Contains(path, qwenCompatiblePath):
return ApiNameChatCompletion
case strings.Contains(path, qwenTextEmbeddingPath):
Expand Down
14 changes: 4 additions & 10 deletions plugins/wasm-go/extensions/ai-security-guard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,26 +384,20 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
ctx.DontReadResponseBody()
return types.ActionContinue
}
headers, err := proxywasm.GetHttpResponseHeaders()
if err != nil {
log.Warnf("failed to get response headers: %v", err)
return types.ActionContinue
}
hdsMap := convertHeaders(headers)
if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") {
statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
if statusCode != "200" {
log.Debugf("response is not 200, skip response body check")
ctx.DontReadResponseBody()
return types.ActionContinue
}
ctx.SetContext("headers", hdsMap)
return types.HeaderStopIteration
}

func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking response body...")
startTime := time.Now().UnixMilli()
hdsMap := ctx.GetContext("headers").(map[string][]string)
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
isStreamingResponse := strings.Contains(contentType, "event-stream")
model := ctx.GetStringContext("requestModel", "unknown")
var content string
if isStreamingResponse {
Expand Down
120 changes: 82 additions & 38 deletions plugins/wasm-go/extensions/ai-statistics/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,83 +303,79 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
// fetches the tracing span value from the specified source.
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
for _, attribute := range config.attributes {
var key, value string
var err error
var key string
var value interface{}
if source == attribute.ValueSource {
key = attribute.Key
switch source {
case FixedValue:
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value)
value = attribute.Value
case RequestHeader:
if value, err = proxywasm.GetHttpRequestHeader(attribute.Value); err == nil {
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
}
value, _ = proxywasm.GetHttpRequestHeader(attribute.Value)
case RequestBody:
raw := gjson.GetBytes(body, attribute.Value).Raw
if len(raw) > 2 {
value = raw[1 : len(raw)-1]
}
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
value = gjson.GetBytes(body, attribute.Value).Value()
case ResponseHeader:
if value, err = proxywasm.GetHttpResponseHeader(attribute.Value); err == nil {
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
}
value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
case ResponseStreamingBody:
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
case ResponseBody:
value = gjson.GetBytes(body, attribute.Value).String()
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
value = gjson.GetBytes(body, attribute.Value).Value()
default:
}
log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value)
if attribute.ApplyToLog {
ctx.SetUserAttribute(key, value)
}
// for metrics
if key == Model || key == InputToken || key == OutputToken {
ctx.SetContext(key, value)
}
if attribute.ApplyToSpan {
setSpanAttribute(key, value, log)
}
}
}
}

func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string {
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
var value string
var value interface{}
if rule == RuleFirst {
for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
value = jsonObj.String()
value = jsonObj.Value()
break
}
}
} else if rule == RuleReplace {
for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
value = jsonObj.String()
value = jsonObj.Value()
}
}
} else if rule == RuleAppend {
// extract llm response
var strValue string
for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
value += jsonObj.String()
strValue += jsonObj.String()
}
}
value = strValue
} else {
log.Errorf("unsupported rule type: %s", rule)
}
return value
}

// Set the tracing span with value.
func setSpanAttribute(key, value string, log wrapper.Log) {
func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
if value != "" {
traceSpanTag := wrapper.TraceSpanTagPrefix + key
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil {
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e)
}
} else {
Expand All @@ -388,36 +384,84 @@ func setSpanAttribute(key, value string, log wrapper.Log) {
}

func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
route := ctx.GetContext(RouteName).(string)
cluster := ctx.GetContext(ClusterName).(string)
// Generate usage metrics
var model string
var inputToken, outputToken int64
var ok bool
var route, cluster, model string
var inputToken, outputToken uint64
route, ok = ctx.GetContext(RouteName).(string)
if !ok {
log.Warnf("RouteName typd assert failed, skip metric record")
return
}
cluster, ok = ctx.GetContext(ClusterName).(string)
if !ok {
log.Warnf("ClusterName typd assert failed, skip metric record")
return
}
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
log.Warnf("get usage information failed, skip metric record")
return
}
model = ctx.GetUserAttribute(Model).(string)
inputToken = ctx.GetUserAttribute(InputToken).(int64)
outputToken = ctx.GetUserAttribute(OutputToken).(int64)
model, ok = ctx.GetUserAttribute(Model).(string)
if !ok {
log.Warnf("Model typd assert failed, skip metric record")
return
}
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
if !ok {
log.Warnf("InputToken typd assert failed, skip metric record")
return
}
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
if !ok {
log.Warnf("OutputToken typd assert failed, skip metric record")
return
}
if inputToken == 0 || outputToken == 0 {
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
return
}
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), uint64(inputToken))
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), uint64(outputToken))
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)

// Generate duration metrics
var llmFirstTokenDuration, llmServiceDuration int64
var llmFirstTokenDuration, llmServiceDuration uint64
// Is stream response
if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil {
llmFirstTokenDuration = ctx.GetUserAttribute(LLMFirstTokenDuration).(int64)
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), uint64(llmFirstTokenDuration))
llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration))
if !ok {
log.Warnf("LLMFirstTokenDuration typd assert failed")
return
}
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
}
if ctx.GetUserAttribute(LLMServiceDuration) != nil {
llmServiceDuration = ctx.GetUserAttribute(LLMServiceDuration).(int64)
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), uint64(llmServiceDuration))
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
if !ok {
log.Warnf("LLMServiceDuration typd assert failed")
return
}
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
}
}

// convertToUInt converts a dynamically typed numeric value into a uint64
// suitable for incrementing a counter metric.
//
// It accepts every built-in signed and unsigned integer width as well as
// float32/float64. Fractional floats are truncated toward zero.
//
// The second return value is false, and the first is 0, when val is not a
// supported numeric type or cannot be represented as a uint64: negative
// numbers, NaN, infinities, and floats at or above 2^64. Rejecting these
// avoids silent wrap-around (e.g. int64(-1) becoming 18446744073709551615)
// and the implementation-dependent result of out-of-range float conversions.
func convertToUInt(val interface{}) (uint64, bool) {
	// Exclusive upper bound for floats; 2^64 is exactly representable as a float64.
	const uint64Limit = float64(1 << 64)

	fromFloat := func(f float64) (uint64, bool) {
		// NaN fails both comparisons, so it is rejected here as well.
		if f >= 0 && f < uint64Limit {
			return uint64(f), true
		}
		return 0, false
	}
	fromInt := func(i int64) (uint64, bool) {
		if i < 0 {
			return 0, false
		}
		return uint64(i), true
	}

	switch v := val.(type) {
	case float32:
		return fromFloat(float64(v))
	case float64:
		return fromFloat(v)
	case int:
		return fromInt(int64(v))
	case int8:
		return fromInt(int64(v))
	case int16:
		return fromInt(int64(v))
	case int32:
		return fromInt(int64(v))
	case int64:
		return fromInt(v)
	case uint:
		return uint64(v), true
	case uint8:
		return uint64(v), true
	case uint16:
		return uint64(v), true
	case uint32:
		return uint64(v), true
	case uint64:
		return v, true
	default:
		return 0, false
	}
}

0 comments on commit 2f5709a

Please sign in to comment.