6 changes: 5 additions & 1 deletion pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -257,7 +257,7 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
}

// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
-// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
+// hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
// For block i, hash(i) = hash(block i content, hash(i-1)).
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
@@ -286,6 +286,10 @@ func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize i
// Add the model to the first block hash so that different models have different hashes even with the same body.
h := xxhash.New()
_, _ = h.Write([]byte(request.TargetModel))
+if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
+_, _ = h.Write([]byte(cacheSalt))
+}
+
prevBlockHash := BlockHash(h.Sum64())
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
h.Reset()
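To make the chained hashing described in the comment above easier to follow, here is a minimal, self-contained sketch (not the plugin code — the function name and the exact way the previous hash is folded into the next block are illustrative assumptions): hash[0] is seeded with the model name plus the optional cache_salt, and each block hash chains in the previous block's hash, so callers with different salts never share block hashes even for identical prompts.

```go
// Illustrative sketch only: seed the chain with model name + optional
// cache_salt, then hash each block together with the previous block's hash.
package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

func sketchHashBlocks(model, cacheSalt, prompt string, blockSize int) []uint64 {
	h := xxhash.New()
	_, _ = h.Write([]byte(model))
	if cacheSalt != "" {
		_, _ = h.Write([]byte(cacheSalt))
	}
	prev := h.Sum64() // hash[0]: model name (+ cache_salt)

	var hashes []uint64
	buf := make([]byte, 8)
	for i := 0; i+blockSize <= len(prompt); i += blockSize {
		h.Reset()
		_, _ = h.Write([]byte(prompt[i : i+blockSize]))
		binary.LittleEndian.PutUint64(buf, prev)
		_, _ = h.Write(buf) // chain in hash(i-1)
		prev = h.Sum64()
		hashes = append(hashes, prev)
	}
	return hashes
}

func main() {
	// Same model and prompt, different cache_salt => disjoint block hashes,
	// so the two callers' prefix caches stay isolated.
	fmt.Println(sketchHashBlocks("food-review-1", "salt-a", "Write as if you were a critic", 8))
	fmt.Println(sketchHashBlocks("food-review-1", "salt-b", "Write as if you were a critic", 8))
}
```

The real hashPrompt additionally caps the number of blocks via maxPrefixBlocks and pulls the user input out of the parsed request body; none of that is reproduced in this sketch.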
16 changes: 16 additions & 0 deletions pkg/epp/scheduling/types/types.go
@@ -56,13 +56,27 @@ type LLMRequestBody struct {
ChatCompletions *ChatCompletionsRequest `json:"chat_completions,omitempty"`
}

+func (r *LLMRequestBody) CacheSalt() string {
+if r.ChatCompletions == nil && r.Completions == nil {
+return ""
+}
+
+if r.ChatCompletions != nil {
+return r.ChatCompletions.CacheSalt
+}
+
+return r.Completions.CacheSalt
+}
+
// CompletionsRequest is a structured representation of the fields we parse out of the
// /v1/completions request body.
// This struct includes fields usable for plugins and scheduling decisions - and not the entire
// API spec.
type CompletionsRequest struct {
// Prompt is the prompt that was sent in the request body.
Prompt string `json:"prompt,omitempty"`
+// CacheSalt is an optional request parameter to isolate prefix caches for security reasons.
+CacheSalt string `json:"cache_salt,omitempty"`
Contributor:
Just to confirm, did you test with both completions and chat completions requests with vLLM and make sure the parsing here works?

Contributor Author:
Yes. I have checked the request body definitions in vLLM: both the completions and the chat completions requests support cache_salt. So I sent the curl requests below.

for completions:

curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{
"model": "food-review-1",
"prompt": "Write as if you were a critic: San Francisco",
"max_tokens": 100,
"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
"temperature": 0
}'

parse result: (screenshot)

for chat completions:

curl -X POST -i ${IP}:${PORT}/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "food-review-1",
        "max_tokens": 100,
        "temperature": 0,
        "cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
        "messages": [
          {
            "role": "developer",
            "content": "You are a helpful assistant."
          },
          {
            "role": "user",
            "content": "Linux is said to be an open source kernel because "
          }
        ]
  }'

parse result: (screenshot)
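The parse-result screenshots are not reproduced here; based on the struct definitions in this PR and the two curl requests above, the parsed bodies would look roughly like the sketch below (the import path is assumed from the repository layout, and this is not the captured output):

```go
// Rough sketch of what the EPP parses out of the two requests above,
// written as literals of the types added/extended in this PR.
package main

import (
	"fmt"

	// Import path assumed from the repository layout.
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

func main() {
	completions := &types.LLMRequestBody{
		Completions: &types.CompletionsRequest{
			Prompt:    "Write as if you were a critic: San Francisco",
			CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
		},
	}
	chat := &types.LLMRequestBody{
		ChatCompletions: &types.ChatCompletionsRequest{
			Messages: []types.Message{
				{Role: "developer", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Linux is said to be an open source kernel because "},
			},
			CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
		},
	}
	fmt.Printf("%+v\n", *completions.Completions)
	fmt.Printf("%+v\n", *chat.ChatCompletions)
}
```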

}

func (r *CompletionsRequest) String() string {
@@ -88,6 +102,8 @@ type ChatCompletionsRequest struct {
ContinueFinalMessage bool `json:"continue_final_message,omitempty"`
AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"`
ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"`
+// CacheSalt is an optional request parameter to isolate prefix caches for security reasons.
+CacheSalt string `json:"cache_salt,omitempty"`
}

func (r *ChatCompletionsRequest) String() string {
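Before moving on to the tests, a minimal sketch of how the new accessor behaves for the three possible body shapes (again assuming the same import path; this is illustration, not code from the PR):

```go
// CacheSalt() returns the chat-completions salt when that body is present,
// falls back to the completions salt, and returns "" when neither body is set.
package main

import (
	"fmt"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

func main() {
	onlyCompletions := &types.LLMRequestBody{
		Completions: &types.CompletionsRequest{CacheSalt: "salt-completions"},
	}
	onlyChat := &types.LLMRequestBody{
		ChatCompletions: &types.ChatCompletionsRequest{CacheSalt: "salt-chat"},
	}
	empty := &types.LLMRequestBody{}

	fmt.Println(onlyCompletions.CacheSalt()) // "salt-completions"
	fmt.Println(onlyChat.CacheSalt())        // "salt-chat"
	fmt.Println(empty.CacheSalt() == "")     // true: no body, no salt
}
```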
38 changes: 38 additions & 0 deletions pkg/epp/util/request/body_test.go
@@ -225,6 +225,44 @@ func TestExtractRequestData(t *testing.T) {
},
wantErr: true,
},
+{
+name: "completions request with cache_salt",
+body: map[string]any{
+"model": "test",
+"prompt": "test prompt",
+"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
+},
+want: &types.LLMRequestBody{
+Completions: &types.CompletionsRequest{
+Prompt: "test prompt",
+CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
+},
+},
+},
+{
+name: "chat completions request with cache_salt",
+body: map[string]any{
+"model": "test",
+"messages": []any{
+map[string]any{
+"role": "system", "content": "this is a system message",
+},
+map[string]any{
+"role": "user", "content": "hello",
+},
+},
+"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
+},
+want: &types.LLMRequestBody{
+ChatCompletions: &types.ChatCompletionsRequest{
+Messages: []types.Message{
+{Role: "system", Content: "this is a system message"},
+{Role: "user", Content: "hello"},
+},
+CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
+},
+},
+},
}

for _, tt := range tests {