Commit 9f3d093

feat: add max-model-len configuration and validation for context window (#82) (#85)
* feat: add max-model-len configuration and validation for context window (#82)
* refactor: remove redundant check for max model length in validateContextWindow
* fix: correct indentation for test entry in simulator configuration tests
* test: add additional test case for simulator configuration
* fix: static lint check errors
* fix: update error message capitalization in validateContextWindow
* fix: update error message capitalization in validateContextWindow
* fix: refactored context window validation func with detailed error messages and update README

Signed-off-by: Mohit Pal Singh <mohit.pal.singh@outlook.com>
1 parent 7656a3c commit 9f3d093

File tree: 8 files changed, +210 −6 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -92,6 +92,7 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started
 - `lora-modules`: a list of LoRA adapters (a list of space-separated JSON strings): '{"name": "name", "path": "lora_path", "base_model_name": "id"}', optional, empty by default
 - `max-loras`: maximum number of LoRAs in a single batch, optional, default is one
 - `max-cpu-loras`: maximum number of LoRAs to store in CPU memory, optional, must be >= than max-loras, default is max-loras
+- `max-model-len`: model's context window, maximum number of tokens in a single request including input and output, optional, default is 1024
 - `max-num-seqs`: maximum number of sequences per iteration (maximum number of inference requests that could be processed at the same time), default is 5
 - `mode`: the simulator mode, optional, by default `random`
   - `echo`: returns the same text that was sent in the request

pkg/llm-d-inference-sim/config.go

Lines changed: 12 additions & 5 deletions
@@ -41,6 +41,9 @@ type configuration struct {
 	// MaxNumSeqs is maximum number of sequences per iteration (the maximum
 	// number of inference requests that could be processed at the same time)
 	MaxNumSeqs int `yaml:"max-num-seqs"`
+	// MaxModelLen is the model's context window, the maximum number of tokens
+	// in a single request including input and output. Default value is 1024.
+	MaxModelLen int `yaml:"max-model-len"`
 	// LoraModulesString is a list of LoRA adapters as strings
 	LoraModulesString []string `yaml:"lora-modules"`
 	// LoraModules is a list of LoRA adapters

@@ -97,11 +100,12 @@ func (c *configuration) unmarshalLoras() error {

 func newConfig() *configuration {
 	return &configuration{
-		Port:       vLLMDefaultPort,
-		MaxLoras:   1,
-		MaxNumSeqs: 5,
-		Mode:       modeRandom,
-		Seed:       time.Now().UnixNano(),
+		Port:        vLLMDefaultPort,
+		MaxLoras:    1,
+		MaxNumSeqs:  5,
+		MaxModelLen: 1024,
+		Mode:        modeRandom,
+		Seed:        time.Now().UnixNano(),
 	}
 }

@@ -151,6 +155,9 @@ func (c *configuration) validate() error {
 	if c.MaxCPULoras < c.MaxLoras {
 		return errors.New("max CPU LoRAs cannot be less than max LoRAs")
 	}
+	if c.MaxModelLen < 1 {
+		return errors.New("max model len cannot be less than 1")
+	}

 	for _, lora := range c.LoraModules {
 		if lora.Name == "" {

pkg/llm-d-inference-sim/config_test.go

Lines changed: 7 additions & 0 deletions
@@ -239,6 +239,12 @@ var _ = Describe("Simulator configuration", func() {
 	}
 	tests = append(tests, test)

+	test = testCase{
+		name: "invalid max-model-len",
+		args: []string{"cmd", "--max-model-len", "0", "--config", "../../manifests/config.yaml"},
+	}
+	tests = append(tests, test)
+
 	DescribeTable("check configurations",
 		func(args []string, expectedConfig *configuration) {
 			config, err := createSimConfig(args)

@@ -264,5 +270,6 @@ var _ = Describe("Simulator configuration", func() {
 		Entry(tests[9].name, tests[9].args),
 		Entry(tests[10].name, tests[10].args),
 		Entry(tests[11].name, tests[11].args),
+		Entry(tests[12].name, tests[12].args),
 	)
 })

pkg/llm-d-inference-sim/request.go

Lines changed: 13 additions & 0 deletions
@@ -42,6 +42,8 @@ type completionRequest interface {
 	getTools() []tool
 	// getToolChoice() returns tool choice (in chat completion)
 	getToolChoice() string
+	// getMaxCompletionTokens returns the maximum completion tokens requested
+	getMaxCompletionTokens() *int64
 }

 // baseCompletionRequest contains base completion request related information

@@ -143,6 +145,13 @@ func (c *chatCompletionRequest) getToolChoice() string {
 	return c.ToolChoice
 }

+func (c *chatCompletionRequest) getMaxCompletionTokens() *int64 {
+	if c.MaxCompletionTokens != nil {
+		return c.MaxCompletionTokens
+	}
+	return c.MaxTokens
+}
+
 // getLastUserMsg returns last message from this request's messages with user role,
 // if does not exist - returns an empty string
 func (req *chatCompletionRequest) getLastUserMsg() string {

@@ -202,6 +211,10 @@ func (c *textCompletionRequest) getToolChoice() string {
 	return ""
 }

+func (c *textCompletionRequest) getMaxCompletionTokens() *int64 {
+	return c.MaxTokens
+}
+
 // createResponseText creates and returns response payload based on this request,
 // i.e., an array of generated tokens, the finish reason, and the number of created
 // tokens
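The chat-completion side prefers max_completion_tokens and falls back to max_tokens, while the text-completion side only has max_tokens. A toy illustration of that precedence (not the simulator's types; the values are made up):

```go
package main

import "fmt"

// pick mirrors the precedence in getMaxCompletionTokens for chat requests:
// use max_completion_tokens when set, otherwise fall back to max_tokens.
func pick(maxCompletionTokens, maxTokens *int64) *int64 {
	if maxCompletionTokens != nil {
		return maxCompletionTokens
	}
	return maxTokens
}

func main() {
	a, b := int64(16), int64(64)
	fmt.Println(*pick(&a, &b)) // 16: max_completion_tokens wins
	fmt.Println(*pick(nil, &b)) // 64: falls back to max_tokens
}
```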

pkg/llm-d-inference-sim/simulator.go

Lines changed: 12 additions & 1 deletion
@@ -150,6 +150,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
 	f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
 	f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch")
 	f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
+	f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")

 	f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
 	f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)")

@@ -372,6 +373,16 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
 		return
 	}

+	// Validate context window constraints
+	promptTokens := vllmReq.getNumberOfPromptTokens()
+	completionTokens := vllmReq.getMaxCompletionTokens()
+	isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen)
+	if !isValid {
+		s.sendCompletionError(ctx, fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion",
+			s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), "BadRequestError", fasthttp.StatusBadRequest)
+		return
+	}
+
 	var wg sync.WaitGroup
 	wg.Add(1)
 	reqCtx := &completionReqCtx{

@@ -530,7 +541,7 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string
 		ctx.Error(err.Error(), fasthttp.StatusInternalServerError)
 	} else {
 		ctx.SetContentType("application/json")
-		ctx.SetStatusCode(fasthttp.StatusNotFound)
+		ctx.SetStatusCode(code)
 		ctx.SetBody(data)
 	}
 }
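The last hunk makes sendCompletionError honor the status code it is given instead of always writing 404 on the JSON path, which is what lets the context-window check above return a 400. A rough standalone sketch of that pattern with fasthttp; the error payload shape here is illustrative, not the simulator's exact schema, and the port is arbitrary.

```go
package main

import (
	"encoding/json"

	"github.com/valyala/fasthttp"
)

// sendError marshals an error payload and, crucially, sets the status code
// that was passed in rather than a hard-coded value.
func sendError(ctx *fasthttp.RequestCtx, msg, errType string, code int) {
	payload := map[string]any{
		"error": map[string]string{"message": msg, "type": errType},
	}
	data, err := json.Marshal(payload)
	if err != nil {
		ctx.Error(err.Error(), fasthttp.StatusInternalServerError)
		return
	}
	ctx.SetContentType("application/json")
	ctx.SetStatusCode(code) // e.g. fasthttp.StatusBadRequest for context-window violations
	ctx.SetBody(data)
}

func main() {
	handler := func(ctx *fasthttp.RequestCtx) {
		sendError(ctx,
			"This model's maximum context length is 10 tokens. However, you requested 13 tokens (5 in the messages, 8 in the completion). Please reduce the length of the messages or completion",
			"BadRequestError", fasthttp.StatusBadRequest)
	}
	_ = fasthttp.ListenAndServe(":8080", handler)
}
```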

pkg/llm-d-inference-sim/simulator_test.go

Lines changed: 107 additions & 0 deletions
@@ -382,4 +382,111 @@ var _ = Describe("Simulator", func() {
 		Expect(err).NotTo(HaveOccurred())
 		Expect(resp.StatusCode).To(Equal(http.StatusOK))
 	})
+
+	Context("max-model-len context window validation", func() {
+		It("Should reject requests exceeding context window", func() {
+			ctx := context.TODO()
+			// Start server with max-model-len=10
+			args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"}
+			client, err := startServerWithArgs(ctx, modeRandom, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			// Test with raw HTTP to verify the error response format
+			reqBody := `{
+				"messages": [{"role": "user", "content": "This is a test message"}],
+				"model": "my_model",
+				"max_tokens": 8
+			}`
+
+			resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).NotTo(HaveOccurred())
+
+			Expect(resp.StatusCode).To(Equal(400))
+			Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens"))
+			Expect(string(body)).To(ContainSubstring("However, you requested 13 tokens"))
+			Expect(string(body)).To(ContainSubstring("5 in the messages, 8 in the completion"))
+			Expect(string(body)).To(ContainSubstring("BadRequestError"))
+
+			// Also test with OpenAI client to ensure it gets an error
+			openaiclient := openai.NewClient(
+				option.WithBaseURL(baseURL),
+				option.WithHTTPClient(client),
+			)
+
+			_, err = openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
+				Messages: []openai.ChatCompletionMessageParamUnion{
+					openai.UserMessage("This is a test message"),
+				},
+				Model:     model,
+				MaxTokens: openai.Int(8),
+			})
+
+			Expect(err).To(HaveOccurred())
+			var apiErr *openai.Error
+			Expect(errors.As(err, &apiErr)).To(BeTrue())
+			Expect(apiErr.StatusCode).To(Equal(400))
+		})
+
+		It("Should accept requests within context window", func() {
+			ctx := context.TODO()
+			// Start server with max-model-len=50
+			args := []string{"cmd", "--model", model, "--mode", modeEcho, "--max-model-len", "50"}
+			client, err := startServerWithArgs(ctx, modeEcho, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			openaiclient := openai.NewClient(
+				option.WithBaseURL(baseURL),
+				option.WithHTTPClient(client),
+			)
+
+			// Send a request within the context window
+			resp, err := openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
+				Messages: []openai.ChatCompletionMessageParamUnion{
+					openai.UserMessage("Hello"),
+				},
+				Model:     model,
+				MaxTokens: openai.Int(5),
+			})
+
+			Expect(err).NotTo(HaveOccurred())
+			Expect(resp.Choices).To(HaveLen(1))
+			Expect(resp.Model).To(Equal(model))
+		})
+
+		It("Should handle text completion requests exceeding context window", func() {
+			ctx := context.TODO()
+			// Start server with max-model-len=10
+			args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"}
+			client, err := startServerWithArgs(ctx, modeRandom, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			// Test with raw HTTP for text completion
+			reqBody := `{
+				"prompt": "This is a long test prompt with many words",
+				"model": "my_model",
+				"max_tokens": 5
+			}`
+
+			resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(reqBody))
+			Expect(err).NotTo(HaveOccurred())
+			defer func() {
+				err := resp.Body.Close()
+				Expect(err).NotTo(HaveOccurred())
+			}()
+
+			body, err := io.ReadAll(resp.Body)
+			Expect(err).NotTo(HaveOccurred())
+
+			Expect(resp.StatusCode).To(Equal(400))
+			Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens"))
+			Expect(string(body)).To(ContainSubstring("BadRequestError"))
+		})
+	})
 })

pkg/llm-d-inference-sim/utils.go

Lines changed: 14 additions & 0 deletions
@@ -58,6 +58,20 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error)
 	return tokens, nil
 }

+// validateContextWindow checks if the request fits within the model's context window
+// Returns validation result, actual completion tokens, and total tokens
+func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) {
+	completionTokens := int64(0)
+	if maxCompletionTokens != nil {
+		completionTokens = *maxCompletionTokens
+	}
+
+	totalTokens := int64(promptTokens) + completionTokens
+	isValid := totalTokens <= int64(maxModelLen)
+
+	return isValid, completionTokens, totalTokens
+}
+
 // getRandomResponseText returns random response text from the pre-defined list of responses
 // considering max completion tokens if it is not nil, and a finish reason (stop or length)
 func getRandomResponseText(maxCompletionTokens *int64) (string, string) {
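To make the arithmetic concrete, here is a local copy of the helper run with the numbers the tests use: 5 prompt tokens plus 8 requested completion tokens against a 10-token window is rejected because 13 > 10, and a nil max-completion-tokens contributes 0. This is an illustration only, not the simulator's package.

```go
package main

import "fmt"

// Local copy of validateContextWindow from utils.go, for illustration only.
func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) {
	completionTokens := int64(0)
	if maxCompletionTokens != nil {
		completionTokens = *maxCompletionTokens
	}
	totalTokens := int64(promptTokens) + completionTokens
	return totalTokens <= int64(maxModelLen), completionTokens, totalTokens
}

func main() {
	eight := int64(8)
	ok, completion, total := validateContextWindow(5, &eight, 10)
	fmt.Println(ok, completion, total) // false 8 13 -> request rejected with a 400

	ok, completion, total = validateContextWindow(5, nil, 10)
	fmt.Println(ok, completion, total) // true 0 5 -> nil max tokens counts as 0
}
```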

pkg/llm-d-inference-sim/utils_test.go

Lines changed: 44 additions & 0 deletions
@@ -43,4 +43,48 @@ var _ = Describe("Utils", func() {
 			Expect(finishReason).Should(Equal(stopFinishReason))
 		})
 	})
+
+	Context("validateContextWindow", func() {
+		It("should pass when total tokens are within limit", func() {
+			promptTokens := 100
+			maxCompletionTokens := int64(50)
+			maxModelLen := 200
+
+			isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
+			Expect(isValid).Should(BeTrue())
+			Expect(actualCompletionTokens).Should(Equal(int64(50)))
+			Expect(totalTokens).Should(Equal(int64(150)))
+		})
+
+		It("should fail when total tokens exceed limit", func() {
+			promptTokens := 150
+			maxCompletionTokens := int64(100)
+			maxModelLen := 200
+
+			isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
+			Expect(isValid).Should(BeFalse())
+			Expect(actualCompletionTokens).Should(Equal(int64(100)))
+			Expect(totalTokens).Should(Equal(int64(250)))
+		})
+
+		It("should handle nil max completion tokens", func() {
+			promptTokens := 100
+			maxModelLen := 200
+
+			isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, nil, maxModelLen)
+			Expect(isValid).Should(BeTrue())
+			Expect(actualCompletionTokens).Should(Equal(int64(0)))
+			Expect(totalTokens).Should(Equal(int64(100)))
+		})
+
+		It("should fail when only prompt tokens exceed limit", func() {
+			promptTokens := 250
+			maxModelLen := 200
+
+			isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, nil, maxModelLen)
+			Expect(isValid).Should(BeFalse())
+			Expect(actualCompletionTokens).Should(Equal(int64(0)))
+			Expect(totalTokens).Should(Equal(int64(250)))
+		})
+	})
 })
