|
1 | 1 | package main
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bytes" |
4 | 5 | "encoding/json"
|
5 | 6 | "errors"
|
6 | 7 | "fmt"
|
@@ -215,35 +216,51 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da
|
215 | 216 | if chatMode == ChatModeNone || chatMode == ChatModeAdmin {
|
216 | 217 | return data
|
217 | 218 | }
|
| 219 | + var inputToken, outputToken int64 |
| 220 | + var consumer string |
| 221 | + if inputToken, outputToken, ok := getUsage(data); ok { |
| 222 | + ctx.SetContext("input_token", inputToken) |
| 223 | + ctx.SetContext("output_token", outputToken) |
| 224 | + } |
| 225 | + |
218 | 226 | // chat completion mode
|
219 | 227 | if !endOfStream {
|
220 | 228 | return data
|
221 | 229 | }
|
222 |
| - inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"}) |
223 |
| - if err != nil { |
224 |
| - return data |
225 |
| - } |
226 |
| - outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"}) |
227 |
| - if err != nil { |
228 |
| - return data |
229 |
| - } |
230 |
| - inputToken, err := strconv.Atoi(string(inputTokenStr)) |
231 |
| - if err != nil { |
232 |
| - return data |
233 |
| - } |
234 |
| - outputToken, err := strconv.Atoi(string(outputTokenStr)) |
235 |
| - if err != nil { |
| 230 | + |
| 231 | + if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil || ctx.GetContext("consumer") == nil { |
236 | 232 | return data
|
237 | 233 | }
|
238 |
| - consumer, ok := ctx.GetContext("consumer").(string) |
239 |
| - if ok { |
240 |
| - totalToken := int(inputToken + outputToken) |
241 |
| - log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken) |
242 |
| - config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil) |
243 |
| - } |
| 234 | + |
| 235 | + inputToken = ctx.GetContext("input_token").(int64) |
| 236 | + outputToken = ctx.GetContext("output_token").(int64) |
| 237 | + consumer = ctx.GetContext("consumer").(string) |
| 238 | + totalToken := int(inputToken + outputToken) |
| 239 | + log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken) |
| 240 | + config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil) |
244 | 241 | return data
|
245 | 242 | }
|
246 | 243 |
|
| 244 | +func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) { |
| 245 | + chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) |
| 246 | + for _, chunk := range chunks { |
| 247 | + // the feature strings are used to identify the usage data, like: |
| 248 | + // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} |
| 249 | + if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) { |
| 250 | + continue |
| 251 | + } |
| 252 | + inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens") |
| 253 | + outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens") |
| 254 | + if inputTokenObj.Exists() && outputTokenObj.Exists() { |
| 255 | + inputTokenUsage = inputTokenObj.Int() |
| 256 | + outputTokenUsage = outputTokenObj.Int() |
| 257 | + ok = true |
| 258 | + return |
| 259 | + } |
| 260 | + } |
| 261 | + return |
| 262 | +} |
| 263 | + |
247 | 264 | func deniedNoKeyAuthData() types.Action {
|
248 | 265 | util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.")
|
249 | 266 | return types.ActionContinue
|
|
0 commit comments