diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index e2498b9a82..dbba80355b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -27,6 +27,7 @@ const ( qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation" qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding" qwenCompatiblePath = "/compatible-mode/v1/chat/completions" + qwenBailianPath = "/api/v1/apps" qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation" qwenTopPMin = 0.000001 @@ -71,7 +72,8 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName } util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) - if m.config.qwenEnableCompatible { + if m.config.IsOriginal() { + } else if m.config.qwenEnableCompatible { util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) } else if apiName == ApiNameChatCompletion { util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) @@ -762,6 +764,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName { switch { case strings.Contains(path, qwenChatCompletionPath), strings.Contains(path, qwenMultimodalGenerationPath), + strings.Contains(path, qwenBailianPath), strings.Contains(path, qwenCompatiblePath): return ApiNameChatCompletion case strings.Contains(path, qwenTextEmbeddingPath): diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 4fa6e07c68..0e0a747fa1 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -384,26 +384,20 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log ctx.DontReadResponseBody() return types.ActionContinue } - headers, err := proxywasm.GetHttpResponseHeaders() - if err 
!= nil { - log.Warnf("failed to get response headers: %v", err) - return types.ActionContinue - } - hdsMap := convertHeaders(headers) - if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") { + statusCode, _ := proxywasm.GetHttpResponseHeader(":status") + if statusCode != "200" { log.Debugf("response is not 200, skip response body check") ctx.DontReadResponseBody() return types.ActionContinue } - ctx.SetContext("headers", hdsMap) return types.HeaderStopIteration } func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { log.Debugf("checking response body...") startTime := time.Now().UnixMilli() - hdsMap := ctx.GetContext("headers").(map[string][]string) - isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") + contentType, _ := proxywasm.GetHttpResponseHeader("content-type") + isStreamingResponse := strings.Contains(contentType, "event-stream") model := ctx.GetStringContext("requestModel", "unknown") var content string if isStreamingResponse { diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go index 1c8765638f..363f59194e 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main.go +++ b/plugins/wasm-go/extensions/ai-statistics/main.go @@ -303,39 +303,33 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag // fetches the tracing span value from the specified source. 
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) { for _, attribute := range config.attributes { - var key, value string - var err error + var key string + var value interface{} if source == attribute.ValueSource { key = attribute.Key switch source { case FixedValue: - log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value) value = attribute.Value case RequestHeader: - if value, err = proxywasm.GetHttpRequestHeader(attribute.Value); err == nil { - log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - } + value, _ = proxywasm.GetHttpRequestHeader(attribute.Value) case RequestBody: - raw := gjson.GetBytes(body, attribute.Value).Raw - if len(raw) > 2 { - value = raw[1 : len(raw)-1] - } - log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) + value = gjson.GetBytes(body, attribute.Value).Value() case ResponseHeader: - if value, err = proxywasm.GetHttpResponseHeader(attribute.Value); err == nil { - log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - } + value, _ = proxywasm.GetHttpResponseHeader(attribute.Value) case ResponseStreamingBody: value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) - log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) case ResponseBody: - value = gjson.GetBytes(body, attribute.Value).String() - log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) + value = gjson.GetBytes(body, attribute.Value).Value() default: } + log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value) if attribute.ApplyToLog { ctx.SetUserAttribute(key, value) } + // for metrics + if key == Model || key == InputToken || key == OutputToken { + ctx.SetContext(key, value) + } if 
attribute.ApplyToSpan { setSpanAttribute(key, value, log) } @@ -343,14 +337,14 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so } } -func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string { +func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} { chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) - var value string + var value interface{} if rule == RuleFirst { for _, chunk := range chunks { jsonObj := gjson.GetBytes(chunk, jsonPath) if jsonObj.Exists() { - value = jsonObj.String() + value = jsonObj.Value() break } } @@ -358,17 +352,19 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l for _, chunk := range chunks { jsonObj := gjson.GetBytes(chunk, jsonPath) if jsonObj.Exists() { - value = jsonObj.String() + value = jsonObj.Value() } } } else if rule == RuleAppend { // extract llm response + var strValue string for _, chunk := range chunks { jsonObj := gjson.GetBytes(chunk, jsonPath) if jsonObj.Exists() { - value += jsonObj.String() + strValue += jsonObj.String() } } + value = strValue } else { log.Errorf("unsupported rule type: %s", rule) } @@ -376,10 +372,10 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l } // Set the tracing span with value. 
-func setSpanAttribute(key, value string, log wrapper.Log) { +func setSpanAttribute(key string, value interface{}, log wrapper.Log) { if value != "" { traceSpanTag := wrapper.TraceSpanTagPrefix + key - if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil { + if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil { log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e) } } else { @@ -388,36 +384,84 @@ func setSpanAttribute(key, value string, log wrapper.Log) { } func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) { - route := ctx.GetContext(RouteName).(string) - cluster := ctx.GetContext(ClusterName).(string) // Generate usage metrics - var model string - var inputToken, outputToken int64 + var ok bool + var route, cluster, model string + var inputToken, outputToken uint64 + route, ok = ctx.GetContext(RouteName).(string) + if !ok { + log.Warnf("RouteName type assert failed, skip metric record") + return + } + cluster, ok = ctx.GetContext(ClusterName).(string) + if !ok { + log.Warnf("ClusterName type assert failed, skip metric record") + return + } if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil { log.Warnf("get usage information failed, skip metric record") return } - model = ctx.GetUserAttribute(Model).(string) - inputToken = ctx.GetUserAttribute(InputToken).(int64) - outputToken = ctx.GetUserAttribute(OutputToken).(int64) + model, ok = ctx.GetUserAttribute(Model).(string) + if !ok { + log.Warnf("Model type assert failed, skip metric record") + return + } + inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken)) + if !ok { + log.Warnf("InputToken type assert failed, skip metric record") + return + } + outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken)) + if !ok { + log.Warnf("OutputToken type assert failed, skip metric record") + return + } if inputToken == 0 
|| outputToken == 0 { log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record") return } - config.incrementCounter(generateMetricName(route, cluster, model, InputToken), uint64(inputToken)) - config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), uint64(outputToken)) + config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken) + config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken) // Generate duration metrics - var llmFirstTokenDuration, llmServiceDuration int64 + var llmFirstTokenDuration, llmServiceDuration uint64 // Is stream response if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil { - llmFirstTokenDuration = ctx.GetUserAttribute(LLMFirstTokenDuration).(int64) - config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), uint64(llmFirstTokenDuration)) + llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration)) + if !ok { + log.Warnf("LLMFirstTokenDuration type assert failed") + return + } + config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration) config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1) } if ctx.GetUserAttribute(LLMServiceDuration) != nil { - llmServiceDuration = ctx.GetUserAttribute(LLMServiceDuration).(int64) - config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), uint64(llmServiceDuration)) + llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration)) + if !ok { + log.Warnf("LLMServiceDuration type assert failed") + return + } + config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration) config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1) } } + +func convertToUInt(val interface{}) (uint64, bool) { + switch v := val.(type) { + case float32: 
+ return uint64(v), true + case float64: + return uint64(v), true + case int32: + return uint64(v), true + case int64: + return uint64(v), true + case uint32: + return uint64(v), true + case uint64: + return v, true + default: + return 0, false + } +}