From 72710e1e1e85d39e552e5326b9c2329c74b3e9ef Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Sun, 13 Jul 2025 02:33:22 +0530 Subject: [PATCH 1/8] feat: add max-model-len configuration and validation for context window (#82) --- pkg/llm-d-inference-sim/config.go | 17 ++-- pkg/llm-d-inference-sim/config_test.go | 7 ++ pkg/llm-d-inference-sim/request.go | 13 +++ pkg/llm-d-inference-sim/simulator.go | 10 ++- pkg/llm-d-inference-sim/simulator_test.go | 101 ++++++++++++++++++++++ pkg/llm-d-inference-sim/utils.go | 21 +++++ pkg/llm-d-inference-sim/utils_test.go | 50 +++++++++++ 7 files changed, 213 insertions(+), 6 deletions(-) diff --git a/pkg/llm-d-inference-sim/config.go b/pkg/llm-d-inference-sim/config.go index d43e58e0..94bc0b71 100644 --- a/pkg/llm-d-inference-sim/config.go +++ b/pkg/llm-d-inference-sim/config.go @@ -41,6 +41,9 @@ type configuration struct { // MaxNumSeqs is maximum number of sequences per iteration (the maximum // number of inference requests that could be processed at the same time) MaxNumSeqs int `yaml:"max-num-seqs"` + // MaxModelLen is the model's context window, the maximum number of tokens + // in a single request including input and output. Default value is 1024. + MaxModelLen int `yaml:"max-model-len"` // LoraModulesString is a list of LoRA adapters as strings LoraModulesString []string `yaml:"lora-modules"` // LoraModules is a list of LoRA adapters @@ -97,11 +100,12 @@ func (c *configuration) unmarshalLoras() error { func newConfig() *configuration { return &configuration{ - Port: vLLMDefaultPort, - MaxLoras: 1, - MaxNumSeqs: 5, - Mode: modeRandom, - Seed: time.Now().UnixNano(), + Port: vLLMDefaultPort, + MaxLoras: 1, + MaxNumSeqs: 5, + MaxModelLen: 1024, + Mode: modeRandom, + Seed: time.Now().UnixNano(), } } @@ -151,6 +155,9 @@ func (c *configuration) validate() error { if c.MaxCPULoras < c.MaxLoras { return errors.New("max CPU LoRAs cannot be less than max LoRAs") } + if c.MaxModelLen < 1 { + return errors.New("max model len cannot be less than 1") + } for _, lora := range c.LoraModules { if lora.Name == "" { diff --git a/pkg/llm-d-inference-sim/config_test.go b/pkg/llm-d-inference-sim/config_test.go index 2ef1023d..68d6ee21 100644 --- a/pkg/llm-d-inference-sim/config_test.go +++ b/pkg/llm-d-inference-sim/config_test.go @@ -194,6 +194,12 @@ var _ = Describe("Simulator configuration", func() { } tests = append(tests, test) + test = testCase{ + name: "invalid max-model-len", + args: []string{"cmd", "--max-model-len", "0", "--config", "../../manifests/config.yaml"}, + } + tests = append(tests, test) + DescribeTable("check configurations", func(args []string, expectedConfig *configuration) { config, err := createSimConfig(args) @@ -217,5 +223,6 @@ var _ = Describe("Simulator configuration", func() { Entry(tests[7].name, tests[7].args), Entry(tests[8].name, tests[8].args), Entry(tests[9].name, tests[9].args), + Entry(tests[10].name, tests[10].args), ) }) diff --git a/pkg/llm-d-inference-sim/request.go b/pkg/llm-d-inference-sim/request.go index 9f0f55ac..4ebfecb3 100644 --- a/pkg/llm-d-inference-sim/request.go +++ b/pkg/llm-d-inference-sim/request.go @@ -42,6 +42,8 @@ type completionRequest interface { getTools() []tool // getToolChoice() returns tool choice (in chat completion) getToolChoice() string + // getMaxCompletionTokens returns the maximum completion tokens requested + getMaxCompletionTokens() *int64 } // baseCompletionRequest contains base completion request related information @@ -143,6 +145,13 @@ func (c *chatCompletionRequest) getToolChoice() string { 
return c.ToolChoice } +func (c *chatCompletionRequest) getMaxCompletionTokens() *int64 { + if c.MaxCompletionTokens != nil { + return c.MaxCompletionTokens + } + return c.MaxTokens +} + // getLastUserMsg returns last message from this request's messages with user role, // if does not exist - returns an empty string func (req *chatCompletionRequest) getLastUserMsg() string { @@ -202,6 +211,10 @@ func (c *textCompletionRequest) getToolChoice() string { return "" } +func (c *textCompletionRequest) getMaxCompletionTokens() *int64 { + return c.MaxTokens +} + // createResponseText creates and returns response payload based on this request, // i.e., an array of generated tokens, the finish reason, and the number of created // tokens diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 98dceea2..9c7bfb62 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -150,6 +150,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error { f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)") f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch") f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory") + f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output") f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences") f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)") @@ -368,6 +369,13 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple return } + // Validate context window constraints + err = validateContextWindow(vllmReq.getNumberOfPromptTokens(), vllmReq.getMaxCompletionTokens(), s.config.MaxModelLen) + if err != nil { + s.sendCompletionError(ctx, err.Error(), "BadRequestError", fasthttp.StatusBadRequest) + return + } + var wg sync.WaitGroup wg.Add(1) reqCtx := &completionReqCtx{ @@ -527,7 +535,7 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { ctx.SetContentType("application/json") - ctx.SetStatusCode(fasthttp.StatusNotFound) + ctx.SetStatusCode(code) ctx.SetBody(data) } } diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 21d54e0f..169d1d79 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -382,4 +382,105 @@ var _ = Describe("Simulator", func() { Expect(err).NotTo(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) }) + + Context("max-model-len context window validation", func() { + It("Should reject requests exceeding context window", func() { + ctx := context.TODO() + // Start server with max-model-len=10 + args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"} + client, err := startServerWithArgs(ctx, modeRandom, args) + Expect(err).NotTo(HaveOccurred()) + + // Test with raw HTTP to verify the error response format 
+ reqBody := `{ + "messages": [{"role": "user", "content": "This is a test message"}], + "model": "my_model", + "max_tokens": 8 + }` + + resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", strings.NewReader(reqBody)) + Expect(err).NotTo(HaveOccurred()) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + Expect(err).NotTo(HaveOccurred()) + + Expect(resp.StatusCode).To(Equal(400)) + Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens")) + Expect(string(body)).To(ContainSubstring("However, you requested 13 tokens")) + Expect(string(body)).To(ContainSubstring("5 in the messages, 8 in the completion")) + Expect(string(body)).To(ContainSubstring("BadRequestError")) + + // Also test with OpenAI client to ensure it gets an error + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err = openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("This is a test message"), + }, + Model: model, + MaxTokens: openai.Int(8), + }) + + Expect(err).To(HaveOccurred()) + var apiErr *openai.Error + Expect(errors.As(err, &apiErr)).To(BeTrue()) + Expect(apiErr.StatusCode).To(Equal(400)) + }) + + It("Should accept requests within context window", func() { + ctx := context.TODO() + // Start server with max-model-len=50 + args := []string{"cmd", "--model", model, "--mode", modeEcho, "--max-model-len", "50"} + client, err := startServerWithArgs(ctx, modeEcho, args) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + // Send a request within the context window + resp, err := openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Hello"), + }, + Model: model, + MaxTokens: openai.Int(5), + }) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp.Choices).To(HaveLen(1)) + Expect(resp.Model).To(Equal(model)) + }) + + It("Should handle text completion requests exceeding context window", func() { + ctx := context.TODO() + // Start server with max-model-len=10 + args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"} + client, err := startServerWithArgs(ctx, modeRandom, args) + Expect(err).NotTo(HaveOccurred()) + + // Test with raw HTTP for text completion + reqBody := `{ + "prompt": "This is a long test prompt with many words", + "model": "my_model", + "max_tokens": 5 + }` + + resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(reqBody)) + Expect(err).NotTo(HaveOccurred()) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + Expect(err).NotTo(HaveOccurred()) + + Expect(resp.StatusCode).To(Equal(400)) + Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens")) + Expect(string(body)).To(ContainSubstring("BadRequestError")) + }) + }) }) diff --git a/pkg/llm-d-inference-sim/utils.go b/pkg/llm-d-inference-sim/utils.go index f4c0d4f2..9509145e 100644 --- a/pkg/llm-d-inference-sim/utils.go +++ b/pkg/llm-d-inference-sim/utils.go @@ -58,6 +58,27 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) return tokens, nil } +// validateContextWindow checks if the request fits within the model's context window +// Returns an error if prompt tokens + max completion tokens 
exceeds the max model length +func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) error { + if maxModelLen <= 0 { + return nil // no limit configured + } + + completionTokens := int64(0) + if maxCompletionTokens != nil { + completionTokens = *maxCompletionTokens + } + + totalTokens := int64(promptTokens) + completionTokens + if totalTokens > int64(maxModelLen) { + return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion.", + maxModelLen, totalTokens, promptTokens, completionTokens) + } + + return nil +} + // getRandomResponseText returns random response text from the pre-defined list of responses // considering max completion tokens if it is not nil, and a finish reason (stop or length) func getRandomResponseText(maxCompletionTokens *int64) (string, string) { diff --git a/pkg/llm-d-inference-sim/utils_test.go b/pkg/llm-d-inference-sim/utils_test.go index c24da458..ba015ae2 100644 --- a/pkg/llm-d-inference-sim/utils_test.go +++ b/pkg/llm-d-inference-sim/utils_test.go @@ -43,4 +43,54 @@ var _ = Describe("Utils", func() { Expect(finishReason).Should(Equal(stopFinishReason)) }) }) + + Context("validateContextWindow", func() { + It("should pass when total tokens are within limit", func() { + promptTokens := 100 + maxCompletionTokens := int64(50) + maxModelLen := 200 + + err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) + Expect(err).Should(BeNil()) + }) + + It("should fail when total tokens exceed limit", func() { + promptTokens := 150 + maxCompletionTokens := int64(100) + maxModelLen := 200 + + err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) + Expect(err).ShouldNot(BeNil()) + Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens")) + Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens")) + Expect(err.Error()).Should(ContainSubstring("150 in the messages, 100 in the completion")) + }) + + It("should pass when no max model length is configured", func() { + promptTokens := 1000 + maxCompletionTokens := int64(1000) + maxModelLen := 0 + + err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) + Expect(err).Should(BeNil()) + }) + + It("should handle nil max completion tokens", func() { + promptTokens := 100 + maxModelLen := 200 + + err := validateContextWindow(promptTokens, nil, maxModelLen) + Expect(err).Should(BeNil()) + }) + + It("should fail when only prompt tokens exceed limit", func() { + promptTokens := 250 + maxModelLen := 200 + + err := validateContextWindow(promptTokens, nil, maxModelLen) + Expect(err).ShouldNot(BeNil()) + Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens")) + Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens")) + }) + }) }) From 39bb7a7062331f7b8b83ca38ea46a114653ac6ad Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Sun, 13 Jul 2025 13:00:15 +0530 Subject: [PATCH 2/8] refactor: remove redundant check for max model length in validateContextWindow --- pkg/llm-d-inference-sim/utils.go | 4 ---- pkg/llm-d-inference-sim/utils_test.go | 9 --------- 2 files changed, 13 deletions(-) diff --git a/pkg/llm-d-inference-sim/utils.go b/pkg/llm-d-inference-sim/utils.go index 9509145e..a3fd0d9b 100644 --- a/pkg/llm-d-inference-sim/utils.go +++ b/pkg/llm-d-inference-sim/utils.go @@ -61,10 
+61,6 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) // validateContextWindow checks if the request fits within the model's context window // Returns an error if prompt tokens + max completion tokens exceeds the max model length func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) error { - if maxModelLen <= 0 { - return nil // no limit configured - } - completionTokens := int64(0) if maxCompletionTokens != nil { completionTokens = *maxCompletionTokens diff --git a/pkg/llm-d-inference-sim/utils_test.go b/pkg/llm-d-inference-sim/utils_test.go index ba015ae2..4d131030 100644 --- a/pkg/llm-d-inference-sim/utils_test.go +++ b/pkg/llm-d-inference-sim/utils_test.go @@ -66,15 +66,6 @@ var _ = Describe("Utils", func() { Expect(err.Error()).Should(ContainSubstring("150 in the messages, 100 in the completion")) }) - It("should pass when no max model length is configured", func() { - promptTokens := 1000 - maxCompletionTokens := int64(1000) - maxModelLen := 0 - - err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) - Expect(err).Should(BeNil()) - }) - It("should handle nil max completion tokens", func() { promptTokens := 100 maxModelLen := 200 From cba24c53825459c66fd697f58ce0d56a4c1ed1b7 Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Sun, 13 Jul 2025 13:06:06 +0530 Subject: [PATCH 3/8] fix: correct indentation for test entry in simulator configuration tests --- pkg/llm-d-inference-sim/config_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/llm-d-inference-sim/config_test.go b/pkg/llm-d-inference-sim/config_test.go index 9e6fc977..0e73ac4f 100644 --- a/pkg/llm-d-inference-sim/config_test.go +++ b/pkg/llm-d-inference-sim/config_test.go @@ -264,6 +264,6 @@ var _ = Describe("Simulator configuration", func() { Entry(tests[8].name, tests[8].args), Entry(tests[9].name, tests[9].args), Entry(tests[10].name, tests[10].args), - Entry(tests[11].name, tests[11].args), + Entry(tests[11].name, tests[11].args), ) }) From ae380e0a78515309cbc2ddb8609984bd48c6995c Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Sun, 13 Jul 2025 13:20:10 +0530 Subject: [PATCH 4/8] test: add additional test case for simulator configuration --- pkg/llm-d-inference-sim/config_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/llm-d-inference-sim/config_test.go b/pkg/llm-d-inference-sim/config_test.go index 0e73ac4f..336ac1c3 100644 --- a/pkg/llm-d-inference-sim/config_test.go +++ b/pkg/llm-d-inference-sim/config_test.go @@ -265,5 +265,6 @@ var _ = Describe("Simulator configuration", func() { Entry(tests[9].name, tests[9].args), Entry(tests[10].name, tests[10].args), Entry(tests[11].name, tests[11].args), + Entry(tests[12].name, tests[12].args), ) }) From 5bbb302f11ec3795274830e71521754a3f3486d2 Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Sun, 13 Jul 2025 13:48:23 +0530 Subject: [PATCH 5/8] fix: static lint check errors --- pkg/llm-d-inference-sim/simulator_test.go | 10 ++++++++-- pkg/llm-d-inference-sim/utils.go | 2 +- pkg/llm-d-inference-sim/utils_test.go | 8 ++++---- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 169d1d79..22a507ae 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -400,7 +400,10 @@ var _ = Describe("Simulator", func() { resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", 
strings.NewReader(reqBody)) Expect(err).NotTo(HaveOccurred()) - defer resp.Body.Close() + defer func() { + err := resp.Body.Close() + Expect(err).NotTo(HaveOccurred()) + }() body, err := io.ReadAll(resp.Body) Expect(err).NotTo(HaveOccurred()) @@ -473,7 +476,10 @@ var _ = Describe("Simulator", func() { resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(reqBody)) Expect(err).NotTo(HaveOccurred()) - defer resp.Body.Close() + defer func() { + err := resp.Body.Close() + Expect(err).NotTo(HaveOccurred()) + }() body, err := io.ReadAll(resp.Body) Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/llm-d-inference-sim/utils.go b/pkg/llm-d-inference-sim/utils.go index a3fd0d9b..b3112b85 100644 --- a/pkg/llm-d-inference-sim/utils.go +++ b/pkg/llm-d-inference-sim/utils.go @@ -68,7 +68,7 @@ func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxMode totalTokens := int64(promptTokens) + completionTokens if totalTokens > int64(maxModelLen) { - return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion.", + return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", maxModelLen, totalTokens, promptTokens, completionTokens) } diff --git a/pkg/llm-d-inference-sim/utils_test.go b/pkg/llm-d-inference-sim/utils_test.go index 4d131030..d61f5b02 100644 --- a/pkg/llm-d-inference-sim/utils_test.go +++ b/pkg/llm-d-inference-sim/utils_test.go @@ -51,7 +51,7 @@ var _ = Describe("Utils", func() { maxModelLen := 200 err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) - Expect(err).Should(BeNil()) + Expect(err).ShouldNot(HaveOccurred()) }) It("should fail when total tokens exceed limit", func() { @@ -60,7 +60,7 @@ var _ = Describe("Utils", func() { maxModelLen := 200 err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) - Expect(err).ShouldNot(BeNil()) + Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens")) Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens")) Expect(err.Error()).Should(ContainSubstring("150 in the messages, 100 in the completion")) @@ -71,7 +71,7 @@ var _ = Describe("Utils", func() { maxModelLen := 200 err := validateContextWindow(promptTokens, nil, maxModelLen) - Expect(err).Should(BeNil()) + Expect(err).ShouldNot(HaveOccurred()) }) It("should fail when only prompt tokens exceed limit", func() { @@ -79,7 +79,7 @@ var _ = Describe("Utils", func() { maxModelLen := 200 err := validateContextWindow(promptTokens, nil, maxModelLen) - Expect(err).ShouldNot(BeNil()) + Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens")) Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens")) }) From 72ab4ba2c5f71ff0898a599022741c0b42c6c277 Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Sun, 13 Jul 2025 13:53:24 +0530 Subject: [PATCH 6/8] fix: update error message capitalization in validateContextWindow --- pkg/llm-d-inference-sim/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/llm-d-inference-sim/utils.go b/pkg/llm-d-inference-sim/utils.go index b3112b85..5b88a497 100644 --- 
a/pkg/llm-d-inference-sim/utils.go +++ b/pkg/llm-d-inference-sim/utils.go @@ -68,7 +68,7 @@ func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxMode totalTokens := int64(promptTokens) + completionTokens if totalTokens > int64(maxModelLen) { - return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", + return fmt.Errorf("this model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", maxModelLen, totalTokens, promptTokens, completionTokens) } From 70b08d5987d2fd7453b34f9d6ff62847b222ea9c Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Sun, 13 Jul 2025 15:54:17 +0530 Subject: [PATCH 7/8] fix: update error message capitalization in validateContextWindow --- pkg/llm-d-inference-sim/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/llm-d-inference-sim/utils.go b/pkg/llm-d-inference-sim/utils.go index 5b88a497..b3112b85 100644 --- a/pkg/llm-d-inference-sim/utils.go +++ b/pkg/llm-d-inference-sim/utils.go @@ -68,7 +68,7 @@ func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxMode totalTokens := int64(promptTokens) + completionTokens if totalTokens > int64(maxModelLen) { - return fmt.Errorf("this model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", + return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", maxModelLen, totalTokens, promptTokens, completionTokens) } From e1b93cd67c9833a320c0b01fb8afc4ca2526b8b1 Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Mon, 14 Jul 2025 13:57:59 +0530 Subject: [PATCH 8/8] fix: refactored context window validation func with detailed error messages and update README --- README.md | 1 + pkg/llm-d-inference-sim/simulator.go | 9 ++++++--- pkg/llm-d-inference-sim/utils.go | 11 ++++------ pkg/llm-d-inference-sim/utils_test.go | 29 +++++++++++++++------------ 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index b992681e..5f4a18e8 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,7 @@ For more details see the = than max-loras, default is max-loras +- `max-model-len`: model's context window, maximum number of tokens in a single request including input and output, optional, default is 1024 - `max-num-seqs`: maximum number of sequences per iteration (maximum number of inference requests that could be processed at the same time), default is 5 - `mode`: the simulator mode, optional, by default `random` - `echo`: returns the same text that was sent in the request diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 22a67eb9..2f952f50 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -374,9 +374,12 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple } // Validate context window constraints - err = validateContextWindow(vllmReq.getNumberOfPromptTokens(), vllmReq.getMaxCompletionTokens(), s.config.MaxModelLen) - if err != nil { - s.sendCompletionError(ctx, err.Error(), "BadRequestError", fasthttp.StatusBadRequest) 
+ promptTokens := vllmReq.getNumberOfPromptTokens() + completionTokens := vllmReq.getMaxCompletionTokens() + isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) + if !isValid { + s.sendCompletionError(ctx, fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), "BadRequestError", fasthttp.StatusBadRequest) return } diff --git a/pkg/llm-d-inference-sim/utils.go b/pkg/llm-d-inference-sim/utils.go index b3112b85..170a0bdb 100644 --- a/pkg/llm-d-inference-sim/utils.go +++ b/pkg/llm-d-inference-sim/utils.go @@ -59,20 +59,17 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) } // validateContextWindow checks if the request fits within the model's context window -// Returns an error if prompt tokens + max completion tokens exceeds the max model length -func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) error { +// Returns validation result, actual completion tokens, and total tokens +func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) { completionTokens := int64(0) if maxCompletionTokens != nil { completionTokens = *maxCompletionTokens } totalTokens := int64(promptTokens) + completionTokens - if totalTokens > int64(maxModelLen) { - return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - maxModelLen, totalTokens, promptTokens, completionTokens) - } + isValid := totalTokens <= int64(maxModelLen) - return nil + return isValid, completionTokens, totalTokens } // getRandomResponseText returns random response text from the pre-defined list of responses diff --git a/pkg/llm-d-inference-sim/utils_test.go b/pkg/llm-d-inference-sim/utils_test.go index d61f5b02..425c09a1 100644 --- a/pkg/llm-d-inference-sim/utils_test.go +++ b/pkg/llm-d-inference-sim/utils_test.go @@ -50,8 +50,10 @@ var _ = Describe("Utils", func() { maxCompletionTokens := int64(50) maxModelLen := 200 - err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) - Expect(err).ShouldNot(HaveOccurred()) + isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) + Expect(isValid).Should(BeTrue()) + Expect(actualCompletionTokens).Should(Equal(int64(50))) + Expect(totalTokens).Should(Equal(int64(150))) }) It("should fail when total tokens exceed limit", func() { @@ -59,29 +61,30 @@ var _ = Describe("Utils", func() { maxCompletionTokens := int64(100) maxModelLen := 200 - err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) - Expect(err).Should(HaveOccurred()) - Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens")) - Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens")) - Expect(err.Error()).Should(ContainSubstring("150 in the messages, 100 in the completion")) + isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen) + Expect(isValid).Should(BeFalse()) + Expect(actualCompletionTokens).Should(Equal(int64(100))) + 
Expect(totalTokens).Should(Equal(int64(250))) }) It("should handle nil max completion tokens", func() { promptTokens := 100 maxModelLen := 200 - err := validateContextWindow(promptTokens, nil, maxModelLen) - Expect(err).ShouldNot(HaveOccurred()) + isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, nil, maxModelLen) + Expect(isValid).Should(BeTrue()) + Expect(actualCompletionTokens).Should(Equal(int64(0))) + Expect(totalTokens).Should(Equal(int64(100))) }) It("should fail when only prompt tokens exceed limit", func() { promptTokens := 250 maxModelLen := 200 - err := validateContextWindow(promptTokens, nil, maxModelLen) - Expect(err).Should(HaveOccurred()) - Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens")) - Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens")) + isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, nil, maxModelLen) + Expect(isValid).Should(BeFalse()) + Expect(actualCompletionTokens).Should(Equal(int64(0))) + Expect(totalTokens).Should(Equal(int64(250))) }) }) })
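
For reference, the validation behavior introduced by this series can be exercised outside the simulator. The standalone sketch below copies the validateContextWindow helper as it stands after PATCH 8/8 and reproduces the caller-side message formatting that handleCompletions uses when it returns a 400 BadRequestError; the main function and the token counts (a 5-token prompt, max_tokens 8, --max-model-len 10) are illustrative values borrowed from the chat-completion test case, not part of the PR itself.

package main

import "fmt"

// validateContextWindow checks if the request fits within the model's context window.
// It returns the validation result, the actual completion tokens, and the total tokens
// (the signature introduced in PATCH 8/8).
func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) {
	completionTokens := int64(0)
	if maxCompletionTokens != nil {
		completionTokens = *maxCompletionTokens
	}

	totalTokens := int64(promptTokens) + completionTokens
	isValid := totalTokens <= int64(maxModelLen)

	return isValid, completionTokens, totalTokens
}

func main() {
	// Values mirroring the rejected chat-completion request in simulator_test.go:
	// a 5-token prompt with max_tokens=8 against a server started with --max-model-len=10.
	promptTokens := 5
	maxCompletionTokens := int64(8)
	maxModelLen := 10

	isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
	if !isValid {
		// Same message the simulator sends back alongside the 400 BadRequestError.
		fmt.Printf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion\n",
			maxModelLen, totalTokens, promptTokens, actualCompletionTokens)
	}
}

Running the sketch prints the same message asserted by the raw-HTTP test above ("maximum context length is 10 tokens", "you requested 13 tokens", "5 in the messages, 8 in the completion").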