5 changes: 2 additions & 3 deletions pkg/llm-d-inference-sim/request.go
@@ -18,7 +18,6 @@ limitations under the License.
 package llmdinferencesim
 
 import (
-	"strings"
 	"sync"
 
 	"github.com/valyala/fasthttp"
@@ -158,7 +157,7 @@ func (c *chatCompletionRequest) getNumberOfPromptTokens() int {
 	for _, message := range c.Messages {
 		messages += message.Content.PlainText() + " "
 	}
-	return len(strings.Fields(messages))
+	return len(tokenize(messages))
 }
 
 func (c *chatCompletionRequest) getTools() []tool {
@@ -224,7 +223,7 @@ type textCompletionRequest struct {
 }
 
 func (t *textCompletionRequest) getNumberOfPromptTokens() int {
-	return len(strings.Fields(t.Prompt))
+	return len(tokenize(t.Prompt))
 }
 
 func (c *textCompletionRequest) getTools() []tool {
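Note: the `tokenize` helper that replaces the direct `strings.Fields` calls above is defined elsewhere in the package and is not part of this diff. A minimal sketch of its likely shape, assuming it simply centralizes whitespace-based splitting (the real helper may apply additional rules, e.g. for punctuation):

```go
package llmdinferencesim

import "strings"

// tokenize is a sketch of the package's single entry point for splitting
// text into tokens. Assumption: whitespace-based splitting, matching the
// strings.Fields calls it replaces; the real implementation may differ.
func tokenize(text string) []string {
	return strings.Fields(text)
}
```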
3 changes: 2 additions & 1 deletion pkg/llm-d-inference-sim/seed_test.go
@@ -42,7 +42,8 @@ var _ = Describe("Simulator with seed", func() {
 			Prompt: openai.CompletionNewParamsPromptUnion{
 				OfString: openai.String(userMessage),
 			},
-			Model: openai.CompletionNewParamsModel(model),
+			Model:     openai.CompletionNewParamsModel(model),
+			MaxTokens: openai.Int(10),
 		}
 
 		resp, err := openaiclient.Completions.New(ctx, params)
4 changes: 4 additions & 0 deletions pkg/llm-d-inference-sim/simulator.go
@@ -354,6 +354,10 @@ func (s *VllmSimulator) validateRequest(req completionRequest) (string, string,
 		return fmt.Sprintf("The model `%s` does not exist.", req.getModel()), "NotFoundError", fasthttp.StatusNotFound
 	}
 
+	if req.getMaxCompletionTokens() != nil && *req.getMaxCompletionTokens() <= 0 {
+		return "Max completion tokens and max tokens should be positive", "Invalid request", fasthttp.StatusBadRequest
+	}
+
 	if req.doRemoteDecode() && req.isStream() {
 		return "Prefill does not support streaming", "Invalid request", fasthttp.StatusBadRequest
 	}
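The new validation dereferences `req.getMaxCompletionTokens()` only after a nil check, so the accessor must return a `*int64` that is nil when the client sent neither `max_completion_tokens` nor `max_tokens`. A sketch of that contract, with assumed field names (the real request struct may differ):

```go
package llmdinferencesim

// chatCompletionRequest sketch: only the two limit fields are shown, and
// their names are assumptions based on the OpenAI API parameters.
type chatCompletionRequest struct {
	MaxTokens           *int64 `json:"max_tokens"`
	MaxCompletionTokens *int64 `json:"max_completion_tokens"`
}

// getMaxCompletionTokens returns the effective completion-token limit:
// max_completion_tokens wins when both are set, and nil means no limit
// was sent, which the validation above accepts.
func (c *chatCompletionRequest) getMaxCompletionTokens() *int64 {
	if c.MaxCompletionTokens != nil {
		return c.MaxCompletionTokens
	}
	return c.MaxTokens
}
```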
59 changes: 35 additions & 24 deletions pkg/llm-d-inference-sim/simulator_test.go
@@ -38,6 +38,9 @@ import (
 const model = "my_model"
 const baseURL = "http://localhost/v1"
 const userMessage = "This is a test."
+const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be positive"
+
+var userMsgTokens int64
 
 func startServer(ctx context.Context, mode string) (*http.Client, error) {
 	return startServerWithArgs(ctx, mode, nil)
@@ -65,6 +68,10 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http
 		return nil, err
 	}
 
+	// calculate the number of tokens in the user message;
+	// must run after parseCommandParamsAndLoadConfig, since that initializes the random engine
+	userMsgTokens = int64(len(tokenize(userMessage)))
+
 	// run request processing workers
 	for i := 1; i <= s.config.MaxNumSeqs; i++ {
 		go s.reqProcessingWorker(ctx, i)
@@ -132,17 +139,19 @@ var _ = Describe("Simulator", func() {
 			}
 
 			Expect(numberOfChunksWithUsage).To(Equal(1))
-			Expect(chunk.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(chunk.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(chunk.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(chunk.Usage.TotalTokens).To(Equal(chunk.Usage.PromptTokens + chunk.Usage.CompletionTokens))
 
 			msg := strings.Join(tokens, "")
-			expectedMsg := userMessage
 			if mode == modeRandom {
-				expectedMsg = getFullTextFromPartialString(msg)
+				// in random mode, ensure that the returned message could be an output of the random text generator
+				Expect(isValidText(msg)).To(BeTrue())
+			} else {
+				// in echo mode, check that the text is returned as-is
+				Expect(msg).Should(Equal(userMessage))
 			}
 			Expect(role).Should(Equal("assistant"))
-			Expect(msg).Should(Equal(expectedMsg))
 		},
 		func(mode string) string {
 			return "mode: " + mode
@@ -189,16 +198,18 @@ var _ = Describe("Simulator", func() {
 				Expect(string(chunk.Object)).To(Equal(textCompletionObject))
 			}
 			Expect(numberOfChunksWithUsage).To(Equal(1))
-			Expect(chunk.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(chunk.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(chunk.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(chunk.Usage.TotalTokens).To(Equal(chunk.Usage.PromptTokens + chunk.Usage.CompletionTokens))
 
 			text := strings.Join(tokens, "")
-			expectedText := userMessage
 			if mode == modeRandom {
-				expectedText = getFullTextFromPartialString(text)
+				// in random mode, ensure that the returned text could be an output of the random text generator
+				Expect(isValidText(text)).To(BeTrue())
+			} else {
+				// in echo mode, check that the text is returned as-is
+				Expect(text).Should(Equal(userMessage))
 			}
-			Expect(text).Should(Equal(expectedText))
 		},
 		func(mode string) string {
 			return "mode: " + mode
@@ -224,18 +235,15 @@
 				Model: model,
 			}
 			numTokens := 0
-			partialErrMsg := ""
 			// if maxTokens and maxCompletionTokens are passed,
 			// maxCompletionTokens is used
 			if maxTokens != 0 {
 				params.MaxTokens = param.NewOpt(int64(maxTokens))
 				numTokens = maxTokens
-				partialErrMsg = "max_tokens must be at least 1, got -1"
 			}
 			if maxCompletionTokens != 0 {
 				params.MaxCompletionTokens = param.NewOpt(int64(maxCompletionTokens))
 				numTokens = maxCompletionTokens
-				partialErrMsg = "max_completion_tokens must be at least 1, got -1"
 			}
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			if err != nil {
@@ -244,7 +252,7 @@
 				if openaiError.StatusCode == 400 {
 					errMsg, err := io.ReadAll(openaiError.Response.Body)
 					Expect(err).NotTo(HaveOccurred())
-					if strings.Contains(string(errMsg), partialErrMsg) {
+					if strings.Contains(string(errMsg), invalidMaxTokensErrMsg) {
 						return
 					}
 				}
@@ -254,22 +262,24 @@
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
 			msg := resp.Choices[0].Message.Content
 			Expect(msg).ShouldNot(BeEmpty())
 
 			if numTokens > 0 {
-				tokens := strings.Fields(msg)
+				tokens := tokenize(msg)
 				Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
 			} else {
-				expectedMsg := userMessage
 				if mode == modeRandom {
-					expectedMsg = getFullTextFromPartialString(msg)
+					// in random mode, ensure that the returned message could be an output of the random text generator
+					Expect(isValidText(msg)).To(BeTrue())
+				} else {
+					// in echo mode, check that the text is returned as-is
+					Expect(msg).Should(Equal(userMessage))
 				}
-				Expect(msg).Should(Equal(expectedMsg))
 			}
 		},
 		func(mode string, maxTokens int, maxCompletionTokens int) string {
@@ -310,7 +320,6 @@ var _ = Describe("Simulator", func() {
 				Model: openai.CompletionNewParamsModel(model),
 			}
 			numTokens := 0
-			partialErrMsg := "max_tokens must be at least 1, got -1"
 			if maxTokens != 0 {
 				params.MaxTokens = param.NewOpt(int64(maxTokens))
 				numTokens = maxTokens
@@ -322,7 +331,7 @@
 				if openaiError.StatusCode == 400 {
 					errMsg, err := io.ReadAll(openaiError.Response.Body)
 					Expect(err).NotTo(HaveOccurred())
-					if strings.Contains(string(errMsg), partialErrMsg) {
+					if strings.Contains(string(errMsg), invalidMaxTokensErrMsg) {
 						return
 					}
 				}
@@ -332,22 +341,24 @@
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(textCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
 			text := resp.Choices[0].Text
 			Expect(text).ShouldNot(BeEmpty())
 
 			if numTokens != 0 {
-				tokens := strings.Fields(text)
+				tokens := tokenize(text)
 				Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
 			} else {
-				expectedText := userMessage
 				if mode == modeRandom {
-					expectedText = getFullTextFromPartialString(text)
+					// in random mode, ensure that the returned text could be an output of the random text generator
+					Expect(isValidText(text)).To(BeTrue())
+				} else {
+					// in echo mode, check that the text is returned as-is
+					Expect(text).Should(Equal(userMessage))
 				}
-				Expect(text).Should(Equal(expectedText))
 			}
 		},
 		func(mode string, maxTokens int) string {
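The tests now call `isValidText` instead of reconstructing the expected string with `getFullTextFromPartialString`; its implementation is not shown in this diff. A rough sketch of the property it plausibly checks, assuming the random generator draws from a fixed sentence bank (`textBank` is a hypothetical name):

```go
package llmdinferencesim

import "strings"

// textBank is a hypothetical stand-in for the sentence bank the random
// text generator draws from.
var textBank = []string{
	"Today is a nice sunny day.",
	"The temperature here is twenty degrees.",
}

// isValidText sketch: a message counts as plausible generator output if it
// is a sentence from the bank, possibly truncated (e.g. by a max-tokens
// limit). The real check may also accept concatenations of sentences.
func isValidText(text string) bool {
	for _, sentence := range textBank {
		if strings.HasPrefix(sentence, strings.TrimSpace(text)) {
			return true
		}
	}
	return false
}
```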
12 changes: 6 additions & 6 deletions pkg/llm-d-inference-sim/tools_test.go
@@ -398,7 +398,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			}
 
 			Expect(numberOfChunksWithUsage).To(Equal(1))
-			Expect(chunk.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(chunk.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(chunk.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(chunk.Usage.TotalTokens).To(Equal(chunk.Usage.PromptTokens + chunk.Usage.CompletionTokens))
 
@@ -451,7 +451,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
@@ -543,7 +543,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
@@ -599,7 +599,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
@@ -685,7 +685,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
@@ -747,7 +747,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 