diff --git a/README.md b/README.md index c3f35c61..d113a315 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,7 @@ For more details see the 0 { - <-simulator.runReqChan + for len(simulator.metrics.runReqChan) > 0 { + <-simulator.metrics.runReqChan } - simulator.runReqChan <- 1 + simulator.metrics.runReqChan <- 1 ttft := simulator.getWaitTimeToFirstToken(128, 0, false) Expect(ttft).To(Equal(42)) @@ -273,7 +275,7 @@ var _ = Describe("Check random latencies", Ordered, func() { simulator.config.TimeToFirstTokenStdDev = 0 simulator.config.TimeFactorUnderLoad = timeFactorUnderLoad simulator.config.MaxNumSeqs = maxNumOfReq - simulator.nRunningReqs = int64(maxNumOfReq) + simulator.metrics.nRunningReqs = int64(maxNumOfReq) ttft := simulator.getWaitTimeToFirstToken(128, 0, false) Expect(ttft).To(Equal(int(float64(42) * timeFactorUnderLoad))) @@ -296,7 +298,7 @@ var _ = Describe("Check random latencies", Ordered, func() { simulator.config.TimeToFirstTokenStdDev = 0 simulator.config.TimeFactorUnderLoad = timeFactorUnderLoad simulator.config.MaxNumSeqs = maxNumOfReq - simulator.nRunningReqs = int64(nCurrNumOfReq) + simulator.metrics.nRunningReqs = int64(nCurrNumOfReq) ttft := simulator.getWaitTimeToFirstToken(128, 0, false) max := timeFactorUnderLoad * float64(42) @@ -318,7 +320,7 @@ var _ = Describe("Check random latencies", Ordered, func() { It("when TimeFactorUnderLoad is 1.0, calcLoadFactor should give 1", func() { simulator.config.TimeFactorUnderLoad = 1.0 simulator.config.MaxNumSeqs = 11 - simulator.nRunningReqs = 3 + simulator.metrics.nRunningReqs = 3 factor := simulator.getCurrLoadFactor() Expect(factor).To(BeNumerically("==", 1.0)) @@ -327,7 +329,7 @@ var _ = Describe("Check random latencies", Ordered, func() { It("when TimeFactorUnderLoad is > 1.0, and sim is fully loaded, calcLoadFactor should give TimeFactorUnderLoad", func() { simulator.config.TimeFactorUnderLoad = 2.0 simulator.config.MaxNumSeqs = 11 - simulator.nRunningReqs = 11 + simulator.metrics.nRunningReqs = 11 factor := simulator.getCurrLoadFactor() Expect(factor).To(BeNumerically("==", simulator.config.TimeFactorUnderLoad)) @@ -337,7 +339,7 @@ var _ = Describe("Check random latencies", Ordered, func() { It("when TimeFactorUnderLoad is > 1.0, and sim is partially loaded, calcLoadFactor should give a value between 1 and TimeFactorUnderLoad", func() { simulator.config.TimeFactorUnderLoad = 2.0 simulator.config.MaxNumSeqs = 11 - simulator.nRunningReqs = 6 + simulator.metrics.nRunningReqs = 6 factor := simulator.getCurrLoadFactor() Expect(factor).To(BeNumerically(">", 1.0)) diff --git a/pkg/llm-d-inference-sim/lora.go b/pkg/llm-d-inference-sim/lora.go index 3102a66a..e5df5dc4 100644 --- a/pkg/llm-d-inference-sim/lora.go +++ b/pkg/llm-d-inference-sim/lora.go @@ -47,7 +47,7 @@ func (s *VllmSimulator) getLoras() []string { return loras } -func (s *VllmSimulator) loadLora(ctx *fasthttp.RequestCtx) { +func (s *VllmSimulator) loadLoraAdaptor(ctx *fasthttp.RequestCtx) { var req loadLoraRequest err := json.Unmarshal(ctx.Request.Body(), &req) if err != nil { @@ -59,7 +59,7 @@ func (s *VllmSimulator) loadLora(ctx *fasthttp.RequestCtx) { s.loraAdaptors.Store(req.LoraName, "") } -func (s *VllmSimulator) unloadLora(ctx *fasthttp.RequestCtx) { +func (s *VllmSimulator) unloadLoraAdaptor(ctx *fasthttp.RequestCtx) { var req unloadLoraRequest err := json.Unmarshal(ctx.Request.Body(), &req) if err != nil { @@ -70,3 +70,75 @@ func (s *VllmSimulator) unloadLora(ctx *fasthttp.RequestCtx) { s.loraAdaptors.Delete(req.LoraName) } + +// Checks if the LoRA 
adaptor is loaded +func (s *VllmSimulator) loraIsLoaded(model string) bool { + if !s.isLora(model) { + return true + } + + s.loras.mux.RLock() + defer s.loras.mux.RUnlock() + + _, ok := s.loras.loadedLoras[model] + return ok +} + +// Load the LoRA adaptor if possible. Return false if not. +func (s *VllmSimulator) loadLora(model string) bool { + if !s.isLora(model) { + return true + } + + s.loras.mux.Lock() + defer s.loras.mux.Unlock() + + // check if this LoRA is already loaded or within maxLoras slots + _, ok := s.loras.loadedLoras[model] + ok = ok || len(s.loras.loadedLoras) < s.loras.maxLoras + if !ok { + // if this LoRA is not loaded, and the number of loaded LoRAs reached + // maxLoras, try to find a LoRA that is not in use, and unload it + for lora, count := range s.loras.loadedLoras { + if count == 0 { + delete(s.loras.loadedLoras, lora) + ok = true + break + } + } + } + if ok { + s.loras.loadedLoras[model]++ + } + return ok +} + +// incrementLora increments the count of running requests using the model +// (if the model is a LoRA). Can be called only for loaded LoRAs (that are +// already in loras.loadedLoras) +func (s *VllmSimulator) incrementLora(model string) { + if !s.isLora(model) { + return + } + + s.loras.mux.Lock() + defer s.loras.mux.Unlock() + s.loras.loadedLoras[model]++ +} + +// decrementLora decrements the count of running requests using the model +// (if the model is a LoRA) +func (s *VllmSimulator) decrementLora(model string) { + if model == "" || !s.isLora(model) { + return + } + + s.loras.mux.Lock() + defer s.loras.mux.Unlock() + + s.loras.loadedLoras[model]-- + if s.loras.loadedLoras[model] <= 0 { + // last usage of this LoRA + s.loras.loraRemovable <- 1 + } +} diff --git a/pkg/llm-d-inference-sim/metrics.go b/pkg/llm-d-inference-sim/metrics.go index 5e785893..35108582 100644 --- a/pkg/llm-d-inference-sim/metrics.go +++ b/pkg/llm-d-inference-sim/metrics.go @@ -36,9 +36,9 @@ import ( // Metrics reported: // - lora_requests_info func (s *VllmSimulator) createAndRegisterPrometheus() error { - s.registry = prometheus.NewRegistry() + s.metrics.registry = prometheus.NewRegistry() - s.loraInfo = prometheus.NewGaugeVec( + s.metrics.loraInfo = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Subsystem: "", Name: "vllm:lora_requests_info", @@ -47,12 +47,12 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelMaxLora, vllmapi.PromLabelRunningLoraAdapters, vllmapi.PromLabelWaitingLoraAdapters}, ) - if err := s.registry.Register(s.loraInfo); err != nil { + if err := s.metrics.registry.Register(s.metrics.loraInfo); err != nil { s.logger.Error(err, "Prometheus lora info gauge register failed") return err } - s.runningRequests = prometheus.NewGaugeVec( + s.metrics.runningRequests = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Subsystem: "", Name: "vllm:num_requests_running", @@ -61,13 +61,13 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelModelName}, ) - if err := s.registry.Register(s.runningRequests); err != nil { + if err := s.metrics.registry.Register(s.metrics.runningRequests); err != nil { s.logger.Error(err, "Prometheus number of running requests gauge register failed") return err } // not supported for now, reports constant value - s.waitingRequests = prometheus.NewGaugeVec( + s.metrics.waitingRequests = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Subsystem: "", Name: "vllm:num_requests_waiting", @@ -76,12 +76,12 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { 
[]string{vllmapi.PromLabelModelName}, ) - if err := s.registry.Register(s.waitingRequests); err != nil { + if err := s.metrics.registry.Register(s.metrics.waitingRequests); err != nil { s.logger.Error(err, "Prometheus number of requests in queue gauge register failed") return err } - s.ttft = prometheus.NewHistogramVec( + s.metrics.ttft = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: "", Name: "vllm:time_to_first_token_seconds", @@ -91,12 +91,12 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelModelName}, ) - if err := s.registry.Register(s.ttft); err != nil { + if err := s.metrics.registry.Register(s.metrics.ttft); err != nil { s.logger.Error(err, "Prometheus time to first token histogram register failed") return err } - s.tpot = prometheus.NewHistogramVec( + s.metrics.tpot = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: "", Name: "vllm:time_per_output_token_seconds", @@ -106,12 +106,12 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelModelName}, ) - if err := s.registry.Register(s.tpot); err != nil { + if err := s.metrics.registry.Register(s.metrics.tpot); err != nil { s.logger.Error(err, "Prometheus time per output token histogram register failed") return err } - s.kvCacheUsagePercentage = prometheus.NewGaugeVec( + s.metrics.kvCacheUsagePercentage = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Subsystem: "", Name: "vllm:gpu_cache_usage_perc", @@ -120,12 +120,12 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { []string{vllmapi.PromLabelModelName}, ) - if err := s.registry.Register(s.kvCacheUsagePercentage); err != nil { + if err := s.metrics.registry.Register(s.metrics.kvCacheUsagePercentage); err != nil { s.logger.Error(err, "Prometheus kv cache usage percentage gauge register failed") return err } - s.requestPromptTokens = prometheus.NewHistogramVec( + s.metrics.requestPromptTokens = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: "", Name: "vllm:request_prompt_tokens", @@ -134,12 +134,12 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { }, []string{vllmapi.PromLabelModelName}, ) - if err := s.registry.Register(s.requestPromptTokens); err != nil { + if err := s.metrics.registry.Register(s.metrics.requestPromptTokens); err != nil { s.logger.Error(err, "Prometheus request_prompt_tokens histogram register failed") return err } - s.requestGenerationTokens = prometheus.NewHistogramVec( + s.metrics.requestGenerationTokens = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: "", Name: "vllm:request_generation_tokens", @@ -148,12 +148,12 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { }, []string{vllmapi.PromLabelModelName}, ) - if err := s.registry.Register(s.requestGenerationTokens); err != nil { + if err := s.metrics.registry.Register(s.metrics.requestGenerationTokens); err != nil { s.logger.Error(err, "Prometheus request_generation_tokens histogram register failed") return err } - s.requestParamsMaxTokens = prometheus.NewHistogramVec( + s.metrics.requestParamsMaxTokens = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: "", Name: "vllm:request_params_max_tokens", @@ -162,12 +162,12 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { }, []string{vllmapi.PromLabelModelName}, ) - if err := s.registry.Register(s.requestParamsMaxTokens); err != nil { + if err := s.metrics.registry.Register(s.metrics.requestParamsMaxTokens); err != nil { s.logger.Error(err, "Prometheus 
request_params_max_tokens histogram register failed") return err } - s.requestSuccessTotal = prometheus.NewCounterVec( + s.metrics.requestSuccessTotal = prometheus.NewCounterVec( prometheus.CounterOpts{ Subsystem: "", Name: "vllm:request_success_total", @@ -175,7 +175,7 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error { }, []string{vllmapi.PromLabelModelName, vllmapi.PromLabelFinishReason}, ) - if err := s.registry.Register(s.requestSuccessTotal); err != nil { + if err := s.metrics.registry.Register(s.metrics.requestSuccessTotal); err != nil { s.logger.Error(err, "Prometheus request_success_total counter register failed") return err } @@ -195,41 +195,41 @@ func (s *VllmSimulator) setInitialPrometheusMetrics() { nWaitingReqs = float64(s.config.FakeMetrics.WaitingRequests) kvCacheUsage = float64(s.config.FakeMetrics.KVCacheUsagePercentage) if s.config.FakeMetrics.TTFTBucketValues != nil { - s.initFakeHistogram(s.ttft, common.TTFTBucketsBoundaries, s.config.FakeMetrics.TTFTBucketValues) + s.initFakeHistogram(s.metrics.ttft, common.TTFTBucketsBoundaries, s.config.FakeMetrics.TTFTBucketValues) } if s.config.FakeMetrics.TPOTBucketValues != nil { - s.initFakeHistogram(s.tpot, common.TPOTBucketsBoundaries, s.config.FakeMetrics.TPOTBucketValues) + s.initFakeHistogram(s.metrics.tpot, common.TPOTBucketsBoundaries, s.config.FakeMetrics.TPOTBucketValues) } buckets := build125Buckets(s.config.MaxModelLen) if s.config.FakeMetrics.RequestPromptTokens != nil { - s.initFakeHistogram(s.requestPromptTokens, buckets, s.config.FakeMetrics.RequestPromptTokens) + s.initFakeHistogram(s.metrics.requestPromptTokens, buckets, s.config.FakeMetrics.RequestPromptTokens) } if s.config.FakeMetrics.RequestGenerationTokens != nil { - s.initFakeHistogram(s.requestParamsMaxTokens, buckets, s.config.FakeMetrics.RequestGenerationTokens) + s.initFakeHistogram(s.metrics.requestParamsMaxTokens, buckets, s.config.FakeMetrics.RequestGenerationTokens) } if s.config.FakeMetrics.RequestParamsMaxTokens != nil { - s.initFakeHistogram(s.requestGenerationTokens, buckets, s.config.FakeMetrics.RequestParamsMaxTokens) + s.initFakeHistogram(s.metrics.requestGenerationTokens, buckets, s.config.FakeMetrics.RequestParamsMaxTokens) } for reason, requestSuccessTotal := range s.config.FakeMetrics.RequestSuccessTotal { - s.requestSuccessTotal.WithLabelValues(modelName, reason).Add(float64(requestSuccessTotal)) + s.metrics.requestSuccessTotal.WithLabelValues(modelName, reason).Add(float64(requestSuccessTotal)) } } - s.runningRequests.WithLabelValues(modelName).Set(nRunningReqs) - s.waitingRequests.WithLabelValues(modelName).Set(nWaitingReqs) - s.kvCacheUsagePercentage.WithLabelValues(modelName).Set(kvCacheUsage) + s.metrics.runningRequests.WithLabelValues(modelName).Set(nRunningReqs) + s.metrics.waitingRequests.WithLabelValues(modelName).Set(nWaitingReqs) + s.metrics.kvCacheUsagePercentage.WithLabelValues(modelName).Set(kvCacheUsage) if s.config.FakeMetrics != nil && len(s.config.FakeMetrics.LoraMetrics) != 0 { for _, metrics := range s.config.FakeMetrics.LoraMetrics { - s.loraInfo.WithLabelValues( + s.metrics.loraInfo.WithLabelValues( strconv.Itoa(s.config.MaxLoras), metrics.RunningLoras, metrics.WaitingLoras).Set(metrics.Timestamp) } } else { - s.loraInfo.WithLabelValues( + s.metrics.loraInfo.WithLabelValues( strconv.Itoa(s.config.MaxLoras), "", "").Set(float64(time.Now().Unix())) @@ -269,27 +269,27 @@ func (s *VllmSimulator) reportLoras() { if s.config.FakeMetrics != nil { return } - if s.loraInfo == nil { + if s.metrics.loraInfo == 
nil { // Happens in the tests return } var runningLoras []string - s.runningLoras.Range(func(key any, _ any) bool { + s.metrics.runningLoras.Range(func(key any, _ any) bool { if lora, ok := key.(string); ok { runningLoras = append(runningLoras, lora) } return true }) var waitingLoras []string - s.waitingLoras.Range(func(key any, _ any) bool { + s.metrics.waitingLoras.Range(func(key any, _ any) bool { if lora, ok := key.(string); ok { waitingLoras = append(waitingLoras, lora) } return true }) - s.loraInfo.WithLabelValues( + s.metrics.loraInfo.WithLabelValues( strconv.Itoa(s.config.MaxLoras), strings.Join(runningLoras, ","), strings.Join(waitingLoras, ",")).Set(float64(time.Now().Unix())) @@ -300,9 +300,9 @@ func (s *VllmSimulator) reportRunningRequests() { if s.config.FakeMetrics != nil { return } - if s.runningRequests != nil { - s.runningRequests.WithLabelValues( - s.getDisplayedModelName(s.config.Model)).Set(float64(s.nRunningReqs)) + if s.metrics.runningRequests != nil { + s.metrics.runningRequests.WithLabelValues( + s.getDisplayedModelName(s.config.Model)).Set(float64(s.metrics.nRunningReqs)) } } @@ -311,9 +311,9 @@ func (s *VllmSimulator) reportWaitingRequests() { if s.config.FakeMetrics != nil { return } - if s.waitingRequests != nil { - s.waitingRequests.WithLabelValues( - s.getDisplayedModelName(s.config.Model)).Set(float64(s.nWaitingReqs)) + if s.metrics.waitingRequests != nil { + s.metrics.waitingRequests.WithLabelValues( + s.getDisplayedModelName(s.config.Model)).Set(float64(s.metrics.nWaitingReqs)) } } @@ -322,8 +322,8 @@ func (s *VllmSimulator) reportTTFT(ttftInSecs float64) { if s.config.FakeMetrics != nil { return } - if s.ttft != nil { - s.ttft.WithLabelValues( + if s.metrics.ttft != nil { + s.metrics.ttft.WithLabelValues( s.getDisplayedModelName(s.config.Model)).Observe(ttftInSecs) } } @@ -333,8 +333,8 @@ func (s *VllmSimulator) reportTPOT(tpotInSecs float64) { if s.config.FakeMetrics != nil { return } - if s.tpot != nil { - s.tpot.WithLabelValues( + if s.metrics.tpot != nil { + s.metrics.tpot.WithLabelValues( s.getDisplayedModelName(s.config.Model)).Observe(tpotInSecs) } } @@ -344,8 +344,8 @@ func (s *VllmSimulator) reportKVCacheUsage(value float64) { if s.config.FakeMetrics != nil { return } - if s.kvCacheUsagePercentage != nil { - s.kvCacheUsagePercentage.WithLabelValues( + if s.metrics.kvCacheUsagePercentage != nil { + s.metrics.kvCacheUsagePercentage.WithLabelValues( s.getDisplayedModelName(s.config.Model)).Set(value) } } @@ -367,8 +367,8 @@ func (s *VllmSimulator) waitingRequestsUpdater(ctx context.Context) { select { case <-ctx.Done(): return - case inc := <-s.waitingReqChan: - s.nWaitingReqs += inc + case inc := <-s.metrics.waitingReqChan: + s.metrics.nWaitingReqs += inc s.reportWaitingRequests() } } @@ -380,8 +380,8 @@ func (s *VllmSimulator) runningRequestsUpdater(ctx context.Context) { select { case <-ctx.Done(): return - case inc := <-s.runReqChan: - s.nRunningReqs += inc + case inc := <-s.metrics.runReqChan: + s.metrics.nRunningReqs += inc s.reportRunningRequests() } } @@ -393,7 +393,7 @@ func (s *VllmSimulator) kvCacheUsageUpdater(ctx context.Context) { select { case <-ctx.Done(): return - case value := <-s.kvCacheUsageChan: + case value := <-s.metrics.kvCacheUsageChan: s.reportKVCacheUsage(value) } } @@ -405,7 +405,7 @@ func (s *VllmSimulator) ttftUpdater(ctx context.Context) { select { case <-ctx.Done(): return - case value := <-s.ttftChan: + case value := <-s.metrics.ttftChan: s.reportTTFT(value) } } @@ -417,7 +417,7 @@ func (s *VllmSimulator) 
tpotUpdater(ctx context.Context) { select { case <-ctx.Done(): return - case value := <-s.tpotChan: + case value := <-s.metrics.tpotChan: s.reportTPOT(value) } } @@ -430,15 +430,15 @@ func (s *VllmSimulator) lorasUpdater(ctx context.Context) { select { case <-ctx.Done(): return - case loraUpdate := <-s.lorasChan: + case loraUpdate := <-s.metrics.lorasChan: switch loraUpdate.state { case waitingUsageState: - s.incrementLoraRefCount(loraUpdate.name, &s.waitingLoras) + s.incrementLoraRefCount(loraUpdate.name, &s.metrics.waitingLoras) case runningUsageState: - s.decrementLoraRefCount(loraUpdate.name, &s.waitingLoras) - s.incrementLoraRefCount(loraUpdate.name, &s.runningLoras) + s.decrementLoraRefCount(loraUpdate.name, &s.metrics.waitingLoras) + s.incrementLoraRefCount(loraUpdate.name, &s.metrics.runningLoras) case doneUsageState: - s.decrementLoraRefCount(loraUpdate.name, &s.runningLoras) + s.decrementLoraRefCount(loraUpdate.name, &s.metrics.runningLoras) } s.reportLoras() } @@ -463,8 +463,6 @@ func (s *VllmSimulator) decrementLoraRefCount(lora string, theMap *sync.Map) { // last lora instance stopped its execution - remove from the map theMap.Delete(lora) } - } else { - s.logger.Error(nil, "Zero model reference", "model", lora) } } @@ -475,7 +473,7 @@ func (s *VllmSimulator) recordRequestUpdater(ctx context.Context) { select { case <-ctx.Done(): return - case event := <-s.requestSuccessChan: + case event := <-s.metrics.requestSuccessChan: s.recordRequestMetricsOnSuccess( event.promptTokens, event.generationTokens, @@ -503,12 +501,12 @@ type requestSuccessEvent struct { func (s *VllmSimulator) recordRequestMetricsOnSuccess(promptTokens, generationTokens int, maxTokens *int64, finishReason string) { modelName := s.getDisplayedModelName(s.config.Model) - s.requestPromptTokens.WithLabelValues(modelName).Observe(float64(promptTokens)) - s.requestGenerationTokens.WithLabelValues(modelName).Observe(float64(generationTokens)) + s.metrics.requestPromptTokens.WithLabelValues(modelName).Observe(float64(promptTokens)) + s.metrics.requestGenerationTokens.WithLabelValues(modelName).Observe(float64(generationTokens)) if maxTokens != nil { - s.requestParamsMaxTokens.WithLabelValues(modelName).Observe(float64(*maxTokens)) + s.metrics.requestParamsMaxTokens.WithLabelValues(modelName).Observe(float64(*maxTokens)) } - s.requestSuccessTotal.WithLabelValues(modelName, finishReason).Inc() + s.metrics.requestSuccessTotal.WithLabelValues(modelName, finishReason).Inc() } // build125Buckets generates histogram buckets in powers of 10 scaled by [1,2,5]. 
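The metrics.go hunk above ends at the doc comment for build125Buckets, whose body is not part of this patch, so the "powers of 10 scaled by [1,2,5]" scheme may be hard to picture from the comment alone. As a rough illustration only — demoBuild125Buckets below is a sketch inferred from that comment, not the repository's implementation — such a builder yields boundaries 1, 2, 5, 10, 20, 50, 100, ... capped by the given maximum:

package main

import "fmt"

// demoBuild125Buckets sketches histogram bucket boundaries in powers of 10
// scaled by 1, 2 and 5, stopping once a boundary would exceed maxValue.
// Illustrative assumption only; the repo's build125Buckets may differ.
func demoBuild125Buckets(maxValue int) []float64 {
	var buckets []float64
	for base := 1; base <= maxValue; base *= 10 {
		for _, scale := range []int{1, 2, 5} {
			boundary := base * scale
			if boundary > maxValue {
				return buckets
			}
			buckets = append(buckets, float64(boundary))
		}
	}
	return buckets
}

func main() {
	// For a max model length of 1024 this prints
	// [1 2 5 10 20 50 100 200 500 1000].
	fmt.Println(demoBuild125Buckets(1024))
}
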
diff --git a/pkg/llm-d-inference-sim/metrics_test.go b/pkg/llm-d-inference-sim/metrics_test.go index 99ff4c3b..9f5b98f2 100644 --- a/pkg/llm-d-inference-sim/metrics_test.go +++ b/pkg/llm-d-inference-sim/metrics_test.go @@ -63,9 +63,29 @@ var paramsLora2 openai.ChatCompletionNewParams = openai.ChatCompletionNewParams{ Model: "lora2", } +var paramsLora3 openai.ChatCompletionNewParams = openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora3", +} + +var paramsLora4 openai.ChatCompletionNewParams = openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora4", +} + +var paramsLora5 openai.ChatCompletionNewParams = openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: "lora5", +} + var _ = Describe("Simulator metrics", Ordered, func() { It("Should send correct running and waiting requests metrics", func() { - modelName := "testmodel" // Three requests, only two can run in parallel, we expect // two running requests and one waiting request in the metrics ctx := context.TODO() @@ -77,9 +97,6 @@ var _ = Describe("Simulator metrics", Ordered, func() { openaiclient, params := getOpenAIClientAndChatParams(client, modelName, userMessage, false) - var wg sync.WaitGroup - wg.Add(1) - for range 3 { go func() { defer GinkgoRecover() @@ -88,23 +105,16 @@ var _ = Describe("Simulator metrics", Ordered, func() { }() } - go func() { - defer wg.Done() - defer GinkgoRecover() - - time.Sleep(300 * time.Millisecond) - metricsResp, err := client.Get(metricsUrl) - Expect(err).NotTo(HaveOccurred()) - Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) - - data, err := io.ReadAll(metricsResp.Body) - Expect(err).NotTo(HaveOccurred()) - metrics := string(data) - Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"testmodel\"} 2")) - Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"testmodel\"} 1")) - }() + time.Sleep(300 * time.Millisecond) + metricsResp, err := client.Get(metricsUrl) + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) - wg.Wait() + data, err := io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"testmodel\"} 2")) + Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"testmodel\"} 1")) }) It("Should record correct prompt and generation token counts", func() { @@ -207,28 +217,20 @@ var _ = Describe("Simulator metrics", Ordered, func() { metrics := strings.Split(string(data), "\n") // We sent two sequentual requests to two different LoRAs, we expect to see (in this order) - // 1. running: empty, waiting: lora1 - // 2. running: lora1, waiting: empty - // 3. running: empty, waiting: lora2 - // 4. running: lora2, waiting: empty - // 5. running: empty, waiting: empty - Expect(isLoraMetricPresent(metrics, emptyArray, lora1Arr)).To(BeTrue()) + // 1. running: lora1, waiting: empty + // 2. running: lora2, waiting: empty + // 3. 
running: empty, waiting: empty Expect(isLoraMetricPresent(metrics, lora1Arr, emptyArray)).To(BeTrue()) - Expect(isLoraMetricPresent(metrics, emptyArray, lora2Arr)).To(BeTrue()) Expect(isLoraMetricPresent(metrics, lora2Arr, emptyArray)).To(BeTrue()) Expect(isLoraMetricPresent(metrics, emptyArray, emptyArray)).To(BeTrue()) // Check the order - timestamp1 := getLoraValidTimestamp(metrics, emptyArray, lora1Arr) - timestamp2 := getLoraValidTimestamp(metrics, lora1Arr, emptyArray) - timestamp3 := getLoraValidTimestamp(metrics, emptyArray, lora2Arr) - timestamp4 := getLoraValidTimestamp(metrics, lora2Arr, emptyArray) - timestamp5 := getLoraValidTimestamp(metrics, emptyArray, emptyArray) + timestamp1 := getLoraValidTimestamp(metrics, lora1Arr, emptyArray) + timestamp2 := getLoraValidTimestamp(metrics, lora2Arr, emptyArray) + timestamp3 := getLoraValidTimestamp(metrics, emptyArray, emptyArray) Expect(timestamp1 <= timestamp2).To(BeTrue()) Expect(timestamp2 <= timestamp3).To(BeTrue()) - Expect(timestamp3 <= timestamp4).To(BeTrue()) - Expect(timestamp4 <= timestamp5).To(BeTrue()) }) It("Should send correct lora metrics for parallel requests with delay", func() { @@ -246,12 +248,13 @@ var _ = Describe("Simulator metrics", Ordered, func() { option.WithHTTPClient(client)) var wg sync.WaitGroup - wg.Add(1) + wg.Add(2) // sends three requests with a delay of 0.5 second between them // request1 for lora1, request2 for lora2, and request 3 for lora1 go func() { time.Sleep(500 * time.Millisecond) + defer wg.Done() defer GinkgoRecover() _, err := openaiclient.Chat.Completions.New(ctx, paramsLora2) Expect(err).NotTo(HaveOccurred()) @@ -277,34 +280,32 @@ var _ = Describe("Simulator metrics", Ordered, func() { Expect(err).NotTo(HaveOccurred()) metrics := strings.Split(string(data), "\n") + // max_loras is 1 by default // We sent 3 requests, we expect to see (in this order) - // 1. running: empty, waiting: lora1 + // 1. running: lora1, waiting: empty // 2. running: lora1, waiting: lora2 - // 3. running: lora1, lora2 (in any order), waiting: lora1 - // 4. running: lora1, lora2 (in any order), waiting: empty - // 5. running: lora1, waiting: empty - // 6. running: empty, waiting: empty - Expect(isLoraMetricPresent(metrics, emptyArray, lora1Arr)).To(BeTrue()) - Expect(isLoraMetricPresent(metrics, lora1Arr, lora2Arr)).To(BeTrue()) - Expect(isLoraMetricPresent(metrics, []string{lora1, lora2}, lora1Arr)).To(BeTrue()) - Expect(isLoraMetricPresent(metrics, []string{lora1, lora2}, emptyArray)).To(BeTrue()) + // 3. running: empty, waiting: lora2 + // 4. running: lora2, waiting: empty + // 5. 
running: empty, waiting: empty + // (Requests 1 and 3 can run in parallel) Expect(isLoraMetricPresent(metrics, lora1Arr, emptyArray)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, lora1Arr, lora2Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, emptyArray, lora2Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, lora2Arr, emptyArray)).To(BeTrue()) Expect(isLoraMetricPresent(metrics, emptyArray, emptyArray)).To(BeTrue()) // Check the order - timestamp1 := getLoraValidTimestamp(metrics, emptyArray, lora1Arr) + timestamp1 := getLoraValidTimestamp(metrics, lora1Arr, emptyArray) timestamp2 := getLoraValidTimestamp(metrics, lora1Arr, lora2Arr) - timestamp3 := getLoraValidTimestamp(metrics, []string{lora1, lora2}, lora1Arr) - timestamp4 := getLoraValidTimestamp(metrics, []string{lora1, lora2}, emptyArray) - timestamp5 := getLoraValidTimestamp(metrics, lora1Arr, emptyArray) - timestamp6 := getLoraValidTimestamp(metrics, emptyArray, emptyArray) + timestamp3 := getLoraValidTimestamp(metrics, emptyArray, lora2Arr) + timestamp4 := getLoraValidTimestamp(metrics, lora2Arr, emptyArray) + timestamp5 := getLoraValidTimestamp(metrics, emptyArray, emptyArray) // in case of requests sent with delay the order is well-defined Expect(timestamp1 <= timestamp2).To(BeTrue()) Expect(timestamp2 <= timestamp3).To(BeTrue()) Expect(timestamp3 <= timestamp4).To(BeTrue()) Expect(timestamp4 <= timestamp5).To(BeTrue()) - Expect(timestamp5 <= timestamp6).To(BeTrue()) }) It("Should send correct lora metrics for parallel requests without delay", func() { @@ -347,36 +348,45 @@ var _ = Describe("Simulator metrics", Ordered, func() { // We sent two parallel requests: first to lora1 and then to lora2, // we expect to see metrics in this order: - // 1. running: empty, waiting: lora1 or lora2 (depends which request received first) - // 2. running: one of the loras, waiting: another lora - // 3. running: both lora2 and lora1 (the order of LoRAs doesn't matter here), waiting: empty + // 1. running: one of the loras, waiting: another lora + // 2. running: empty, waiting: another lora + // 3. running: the second lora, waiting: empty // 4. running: empty, waiting: empty - Expect(isLoraMetricPresent(metrics, emptyArray, lora1Arr) || isLoraMetricPresent(metrics, emptyArray, lora2Arr)).To(BeTrue()) Expect(isLoraMetricPresent(metrics, lora1Arr, lora2Arr) || isLoraMetricPresent(metrics, lora2Arr, lora1Arr)).To(BeTrue()) - Expect(isLoraMetricPresent(metrics, []string{lora1, lora2}, emptyArray)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, emptyArray, lora1Arr) || isLoraMetricPresent(metrics, emptyArray, lora2Arr)).To(BeTrue()) + Expect(isLoraMetricPresent(metrics, lora1Arr, emptyArray) || isLoraMetricPresent(metrics, lora2Arr, emptyArray)).To(BeTrue()) Expect(isLoraMetricPresent(metrics, emptyArray, emptyArray)).To(BeTrue()) // Check the order: - // 1. one of the loras in the waiting list - // 2. both loras in the running list - // 3. 
empty - l1WaitingTimestamp, err := getLoraTimestamp(metrics, emptyArray, lora1Arr) + l1RunningL2Waiting, err := getLoraTimestamp(metrics, lora1Arr, lora2Arr) Expect(err).NotTo(HaveOccurred()) - l2WaitingTimestamp, err := getLoraTimestamp(metrics, emptyArray, lora2Arr) + l2RunningL1Waiting, err := getLoraTimestamp(metrics, lora2Arr, lora1Arr) + Expect(err).NotTo(HaveOccurred()) + l1WatingEmptyRunning, err := getLoraTimestamp(metrics, emptyArray, lora1Arr) + Expect(err).NotTo(HaveOccurred()) + l2WatingEmptyRunning, err := getLoraTimestamp(metrics, emptyArray, lora2Arr) + Expect(err).NotTo(HaveOccurred()) + l1RunningEmptyWaiting, err := getLoraTimestamp(metrics, lora1Arr, emptyArray) + Expect(err).NotTo(HaveOccurred()) + l2RunningEmptyWaiting, err := getLoraTimestamp(metrics, lora2Arr, emptyArray) Expect(err).NotTo(HaveOccurred()) - Expect((l1WaitingTimestamp != nil)).ToNot(Equal((l2WaitingTimestamp != nil))) - var singleWaitingTimestamp float64 - if l1WaitingTimestamp != nil { - singleWaitingTimestamp = *l1WaitingTimestamp - } else { - singleWaitingTimestamp = *l2WaitingTimestamp - } - - bothRunningTimestamp := getLoraValidTimestamp(metrics, []string{lora1, lora2}, emptyArray) emptyTimestamp := getLoraValidTimestamp(metrics, emptyArray, emptyArray) - Expect(singleWaitingTimestamp <= bothRunningTimestamp).To(BeTrue()) - Expect(bothRunningTimestamp <= emptyTimestamp).To(BeTrue()) + if l1RunningL2Waiting != nil { + Expect(l2RunningL1Waiting).To(BeNil()) + Expect(l2WatingEmptyRunning).NotTo(BeNil()) + Expect(l2RunningEmptyWaiting).NotTo(BeNil()) + Expect(*l1RunningL2Waiting <= *l2WatingEmptyRunning).To(BeTrue()) + Expect(*l2WatingEmptyRunning <= *l2RunningEmptyWaiting).To(BeTrue()) + Expect(*l2RunningEmptyWaiting <= emptyTimestamp).To(BeTrue()) + } else { + Expect(l2RunningL1Waiting).NotTo(BeNil()) + Expect(l1WatingEmptyRunning).NotTo(BeNil()) + Expect(l1RunningEmptyWaiting).NotTo(BeNil()) + Expect(*l2RunningL1Waiting <= *l1WatingEmptyRunning).To(BeTrue()) + Expect(*l1WatingEmptyRunning <= *l1RunningEmptyWaiting).To(BeTrue()) + Expect(*l1RunningEmptyWaiting <= emptyTimestamp).To(BeTrue()) + } }) It("Should send correct ttft and tpot metrics", func() { @@ -817,12 +827,16 @@ func isLoraMetricPresent(metrics []string, running, waiting []string) bool { // getLoraTimestamp returns timestamp or nil, error func getLoraTimestamp(metrics []string, running, waiting []string) (*float64, error) { - mertic := findLoraMetric(metrics, running, waiting) - if mertic == "" { + metric := findLoraMetric(metrics, running, waiting) + if metric == "" { return nil, nil // not found } + return extractTimestamp(metric) +} + +func extractTimestamp(metric string) (*float64, error) { // Extract timestamp: last part after space - parts := strings.Split(mertic, " ") + parts := strings.Split(metric, " ") if len(parts) < 2 { return nil, errors.New("invalid metric format") } @@ -840,6 +854,28 @@ func getLoraValidTimestamp(metrics []string, running, waiting []string) float64 return *timestamp } +func getLastLoraMetrics(metrics []string) ([]string, error) { + lastTimestamp := float64(0) + var lastMetrics []string + for _, metric := range metrics { + if strings.HasPrefix(metric, "vllm:lora_requests_info") { + timestamp, err := extractTimestamp(metric) + if err != nil { + return nil, err + } + if lastTimestamp > *timestamp { + continue + } + lastTimestamp = *timestamp + if lastTimestamp < *timestamp { + lastMetrics = make([]string, 0) + } + lastMetrics = append(lastMetrics, metric) + } + } + return lastMetrics, nil +} + // 
findLoraMetric finds the relevant metric by comparing with the given loras sets (ignoring order) // metrics: lines of metrics // running: list of running loras to find diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go index 6384f28d..6fba147d 100644 --- a/pkg/llm-d-inference-sim/server.go +++ b/pkg/llm-d-inference-sim/server.go @@ -54,7 +54,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) r.POST("/v1/load_lora_adapter", s.HandleLoadLora) r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora) // supports /metrics prometheus API - r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{}))) + r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.metrics.registry, promhttp.HandlerOpts{}))) // supports standard Kubernetes health and readiness checks r.GET("/health", s.HandleHealth) r.GET("/ready", s.HandleReady) @@ -144,13 +144,11 @@ func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion b // HandleChatCompletions http handler for /v1/chat/completions func (s *VllmSimulator) HandleChatCompletions(ctx *fasthttp.RequestCtx) { - s.logger.Info("chat completion request received") s.handleCompletions(ctx, true) } // HandleTextCompletions http handler for /v1/completions func (s *VllmSimulator) HandleTextCompletions(ctx *fasthttp.RequestCtx) { - s.logger.Info("completion request received") s.handleCompletions(ctx, false) } @@ -209,12 +207,12 @@ func (s *VllmSimulator) HandleTokenize(ctx *fasthttp.RequestCtx) { func (s *VllmSimulator) HandleLoadLora(ctx *fasthttp.RequestCtx) { s.logger.Info("load lora request received") - s.loadLora(ctx) + s.loadLoraAdaptor(ctx) } func (s *VllmSimulator) HandleUnloadLora(ctx *fasthttp.RequestCtx) { s.logger.Info("unload lora request received") - s.unloadLora(ctx) + s.unloadLoraAdaptor(ctx) } func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, int) { diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 30a03148..0ed2b987 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -18,7 +18,9 @@ limitations under the License. 
package llmdinferencesim import ( + "container/list" "context" + "errors" "fmt" "os" "strings" @@ -49,8 +51,6 @@ const ( namespaceHeader = "x-inference-namespace" podNameEnv = "POD_NAME" podNsEnv = "POD_NAMESPACE" - - maxNumberOfRequests = 1000 ) type loraUsageState int @@ -68,14 +68,8 @@ type loraUsage struct { state loraUsageState } -// VllmSimulator simulates vLLM server supporting OpenAI API -type VllmSimulator struct { - // logger is used for information and errors logging - logger logr.Logger - // config is the simulator's configuration - config *common.Configuration - // loraAdaptors contains list of LoRA available adaptors - loraAdaptors sync.Map +// Prometheus metrics +type metricsData struct { // runningLoras is a collection of running loras, // the key is lora's name, the value is the number of running requests using this lora runningLoras sync.Map @@ -122,8 +116,32 @@ type VllmSimulator struct { requestParamsMaxTokens *prometheus.HistogramVec // requestSuccessTotal is prometheus counter for total number of successful requests requestSuccessTotal *prometheus.CounterVec - // channel for requeasts to be passed to workers - reqChan chan *openaiserverapi.CompletionReqCtx +} + +// LoRAs usage info for requests execution +type lorasUsageInfo struct { + mux sync.RWMutex + // lora adapter name -> reference count (number of currently running requests) + loadedLoras map[string]int + // channel for "there is a LoRA that can be removed" event + loraRemovable chan int + // maximum number of LoRAs that can be used simultaneously + maxLoras int +} + +type requestCompleted struct { + worker *worker + model string +} + +// VllmSimulator simulates vLLM server supporting OpenAI API +type VllmSimulator struct { + // logger is used for information and errors logging + logger logr.Logger + // config is the simulator's configuration + config *common.Configuration + // loraAdaptors contains list of LoRA available adaptors + loraAdaptors sync.Map // schema validator for tools parameters toolsValidator *openaiserverapi.Validator // kv cache functionality @@ -136,6 +154,23 @@ type VllmSimulator struct { tokenizer tokenization.Tokenizer // dataset is used for token generation in responses dataset dataset.Dataset + // metrics contains all Prometheus metrics related data + metrics metricsData + // loras contains information about which LoRAs are in use + loras *lorasUsageInfo + + // a channel for free workers + freeWorkers chan *worker + // a channel to indicate that a worker finished working on a request + workerFinished chan *requestCompleted + // waiting requests queue mutex + queueLock sync.Mutex + // bi-directional list of *openaiserverapi.CompletionReqCtx + waitingQueue *list.List + // the max capacity of the waiting requests queue + queueCapacity int + // a channel for incoming requests + newRequests chan *openaiserverapi.CompletionReqCtx } // New creates a new VllmSimulator instance with the given logger @@ -146,19 +181,15 @@ func New(logger logr.Logger) (*VllmSimulator, error) { } return &VllmSimulator{ - logger: logger, - reqChan: make(chan *openaiserverapi.CompletionReqCtx, maxNumberOfRequests), - toolsValidator: toolsValidator, - kvcacheHelper: nil, // kvcache helper will be created only if required after reading configuration - namespace: os.Getenv(podNsEnv), - pod: os.Getenv(podNameEnv), - runReqChan: make(chan int64, maxNumberOfRequests), - waitingReqChan: make(chan int64, maxNumberOfRequests), - ttftChan: make(chan float64, maxNumberOfRequests), - tpotChan: make(chan float64, 
maxNumberOfRequests), - lorasChan: make(chan loraUsage, maxNumberOfRequests), - kvCacheUsageChan: make(chan float64, maxNumberOfRequests), - requestSuccessChan: make(chan requestSuccessEvent, maxNumberOfRequests), + logger: logger, + toolsValidator: toolsValidator, + kvcacheHelper: nil, // kvcache helper will be created only if required after reading configuration + namespace: os.Getenv(podNsEnv), + pod: os.Getenv(podNameEnv), + loras: &lorasUsageInfo{ + loadedLoras: make(map[string]int), + }, + waitingQueue: list.New(), }, nil } @@ -207,11 +238,41 @@ func (s *VllmSimulator) Start(ctx context.Context) error { } func (s *VllmSimulator) startSim(ctx context.Context) error { + if err := s.initializeSim(ctx); err != nil { + return err + } + + listener, err := s.newListener() + if err != nil { + s.logger.Error(err, "Failed to create listener") + return fmt.Errorf("listener creation error: %w", err) + } + + // start the http server with context support + return s.startServer(ctx, listener) +} + +func (s *VllmSimulator) initializeSim(ctx context.Context) error { + common.InitRandom(s.config.Seed) + for _, lora := range s.config.LoraModules { s.loraAdaptors.Store(lora.Name, "") } + s.loras.maxLoras = s.config.MaxLoras + s.loras.loraRemovable = make(chan int, s.config.MaxNumSeqs) - common.InitRandom(s.config.Seed) + s.queueCapacity = s.config.MaxWaitingQueueLength + + maxNumberOfRequests := s.config.MaxNumSeqs + s.config.MaxWaitingQueueLength + s.metrics.runReqChan = make(chan int64, maxNumberOfRequests) + s.metrics.waitingReqChan = make(chan int64, maxNumberOfRequests) + s.metrics.lorasChan = make(chan loraUsage, maxNumberOfRequests) + s.metrics.kvCacheUsageChan = make(chan float64, maxNumberOfRequests) + s.metrics.ttftChan = make(chan float64, maxNumberOfRequests) + s.metrics.tpotChan = make(chan float64, maxNumberOfRequests) + s.metrics.requestSuccessChan = make(chan requestSuccessEvent, maxNumberOfRequests) + + s.newRequests = make(chan *openaiserverapi.CompletionReqCtx, maxNumberOfRequests) // initialize prometheus metrics err := s.createAndRegisterPrometheus() @@ -229,7 +290,7 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { } if s.config.EnableKVCache { - s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer) + s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.metrics.kvCacheUsageChan, s.tokenizer) if err != nil { return err } @@ -243,20 +304,25 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { } // run request processing workers + s.freeWorkers = make(chan *worker, s.config.MaxNumSeqs) + s.workerFinished = make(chan *requestCompleted, s.config.MaxNumSeqs) for i := 1; i <= s.config.MaxNumSeqs; i++ { - go s.reqProcessingWorker(ctx, i) + worker := &worker{ + id: i, + ctx: ctx, + logger: s.logger, + finishedChan: s.workerFinished, + reqChan: make(chan *openaiserverapi.CompletionReqCtx), + processor: s, + } + go worker.waitForRequests() + s.freeWorkers <- worker } s.startMetricsUpdaters(ctx) - listener, err := s.newListener() - if err != nil { - s.logger.Error(err, "Failed to create listener") - return fmt.Errorf("listener creation error: %w", err) - } - - // start the http server with context support - return s.startServer(ctx, listener) + go s.processing(ctx) + return nil } func (s *VllmSimulator) initDataset(ctx context.Context) error { @@ -294,6 +360,101 @@ func (s *VllmSimulator) Printf(format string, args ...interface{}) { s.logger.Info("Server error", "msg", fmt.Sprintf(format, args...)) } 
+func (s *VllmSimulator) processing(ctx context.Context) { + s.logger.Info("Start processing routine") + + for { + select { + case <-ctx.Done(): + s.logger.Info("Request processing done") + return + case completedReq := <-s.workerFinished: + s.logger.V(4).Info("Worker finished") + worker := completedReq.worker + s.decrementLora(completedReq.model) + // there is a free worker - find a request for it and send this request for + // processing with this worker + s.findRequestAndSendToProcess(worker) + case <-s.loras.loraRemovable: + // there is a LoRA that can be removed, go through availbale workers + // and queued requests and find requests that can run now, + // stop if there are no free workers, or no requests + s.logger.V(4).Info("LoRA can be removed") + for { + // check if there is a free worker + worker := s.getFreeWorker() + if worker == nil { + break + } + // check if there is a request that can run and send this request for + // processing with this worker + requestFound := s.findRequestAndSendToProcess(worker) + if !requestFound { + // there are no requests to run (either the queue is empty or maxLoras was reached) + break + } + } + case reqCtx := <-s.newRequests: + // A new request was received. Find a free worker, and check that the request can run LoRA wise. + model := reqCtx.CompletionReq.GetModel() + + worker := s.getFreeWorker() + if worker == nil { + s.logger.V(4).Info("No free worker - sending the request to the waiting queue", + "model", reqCtx.CompletionReq.GetModel(), "req id", reqCtx.CompletionReq.GetRequestID()) + // no free worker, add this request to the waiting queue + s.addRequestToQueue(reqCtx) + break + } + + // check if lora usage allows the request to run + if s.isLora(model) && !s.loadLora(model) { + // free the worker + s.freeWorkers <- worker + s.logger.V(4).Info("LoRA cannot be loaded - sending the request to the waiting queue", + "LoRA", model, "req id", reqCtx.CompletionReq.GetRequestID()) + // LoRA max reached, try to enqueue + s.addRequestToQueue(reqCtx) + break + } + + s.logger.V(4).Info("Sending the request to the processing channel", "model", model, + "req id", reqCtx.CompletionReq.GetRequestID(), "worker", worker.id) + worker.reqChan <- reqCtx + } + } +} + +func (s *VllmSimulator) findRequestAndSendToProcess(worker *worker) bool { + nextReq := s.dequeue() + if nextReq != nil { + // send this request for processing in this worker + s.logger.V(4).Info("Sending request to processing", "model", nextReq.CompletionReq.GetModel(), + "req", nextReq.CompletionReq.GetRequestID(), "worker", worker.id) + worker.reqChan <- nextReq + // decrement waiting requests metric + s.metrics.waitingReqChan <- -1 + return true + } + + // no waiting request, return worker to be free + s.freeWorkers <- worker + return false +} + +func (s *VllmSimulator) addRequestToQueue(reqCtx *openaiserverapi.CompletionReqCtx) { + if err := s.enqueue(reqCtx); err != nil { + s.logger.Error(err, "failed to enqueue request") + reqCtx.HTTPReqCtx.Error("Failed to enqueue request, "+err.Error(), fasthttp.StatusTooManyRequests) + reqCtx.Wg.Done() + return + } + // increment the waiting requests metric + s.metrics.waitingReqChan <- 1 + // update loraInfo metrics with the new waiting request + s.metrics.lorasChan <- loraUsage{reqCtx.CompletionReq.GetModel(), waitingUsageState} +} + // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { // Check if we should inject 
a failure @@ -316,6 +477,8 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple return } + s.logger.V(4).Info("Completion request received", "req id", vllmReq.GetRequestID(), "isChat", isChatCompletion) + var wg sync.WaitGroup wg.Add(1) reqCtx := &openaiserverapi.CompletionReqCtx{ @@ -324,130 +487,18 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple IsChatCompletion: isChatCompletion, Wg: &wg, } - // increment the waiting requests metric - s.waitingReqChan <- 1 - if s.isLora(reqCtx.CompletionReq.GetModel()) { - // update loraInfo metrics with the new waiting request - s.lorasChan <- loraUsage{reqCtx.CompletionReq.GetModel(), waitingUsageState} - } - // send the request to the waiting queue (channel) - s.reqChan <- reqCtx + s.newRequests <- reqCtx wg.Wait() } -func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { - for { - select { - case <-ctx.Done(): - s.logger.Info("reqProcessingWorker stopped:", "worker id", id) - return - case reqCtx, ok := <-s.reqChan: - if !ok { - s.logger.Info("reqProcessingWorker worker exiting: reqChan closed") - return - } - - req := reqCtx.CompletionReq - model := req.GetModel() - displayModel := s.getDisplayedModelName(model) - - // decrement waiting and increment running requests count - s.waitingReqChan <- -1 - s.runReqChan <- 1 - - if s.isLora(model) { - // update loraInfo metric to reflect that - // the request has changed its status from waiting to running - s.lorasChan <- loraUsage{model, runningUsageState} - } - - if s.config.EnableKVCache && !reqCtx.IsChatCompletion { - // kv cache is currently supported for /completion API only - if err := s.kvcacheHelper.OnRequestStart(req); err != nil { - s.sendCompletionError(reqCtx.HTTPReqCtx, openaiserverapi.NewCompletionError(err.Error(), fasthttp.StatusInternalServerError, nil), false) - } - } - - var responseTokens []string - var finishReason string - var err error - var toolCalls []openaiserverapi.ToolCall - var completionTokens int - if reqCtx.IsChatCompletion && - req.GetToolChoice() != openaiserverapi.ToolChoiceNone && - req.GetTools() != nil { - toolCalls, completionTokens, err = - openaiserverapi.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config) - finishReason = dataset.ToolsFinishReason - } - if toolCalls == nil && err == nil { - // Either no tool calls were defined, or we randomly chose not to create tool calls, - // so we generate a response text. 
- responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode) - completionTokens += len(responseTokens) - } - if err != nil { - prefix := "" - if reqCtx.IsChatCompletion { - prefix = "failed to create chat response" - } else { - prefix = "failed to create text response" - } - s.logger.Error(err, prefix) - reqCtx.HTTPReqCtx.Error(prefix+err.Error(), fasthttp.StatusBadRequest) - } else { - usageData := openaiserverapi.Usage{ - PromptTokens: req.GetNumberOfPromptTokens(), - CompletionTokens: completionTokens, - TotalTokens: req.GetNumberOfPromptTokens() + completionTokens, - } - if req.IsStream() { - var usageDataToSend *openaiserverapi.Usage - if req.IncludeUsage() { - usageDataToSend = &usageData - } - s.sendStreamingResponse( - &streamingContext{ - ctx: reqCtx.HTTPReqCtx, - isChatCompletion: reqCtx.IsChatCompletion, - model: displayModel, - doRemotePrefill: req.IsDoRemotePrefill(), - nPromptTokens: usageData.PromptTokens, - nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(), - }, - responseTokens, toolCalls, finishReason, usageDataToSend, - ) - } else { - if req.IsDoRemoteDecode() { - // in case this is prefill pod processing, return special finish reason - finishReason = dataset.RemoteDecodeFinishReason - } - s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData) - } - select { - case s.requestSuccessChan <- requestSuccessEvent{ - promptTokens: usageData.PromptTokens, - generationTokens: usageData.CompletionTokens, - maxTokens: reqCtx.CompletionReq.GetMaxCompletionTokens(), - finishReason: finishReason, - }: - default: - s.logger.V(1).Info("requestSuccessChan full, dropping success event") - } - } - reqCtx.Wg.Done() - } - } -} - // request processing finished func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool, requestID string) { - // decriment running requests count - s.runReqChan <- -1 + // decrement running requests count + s.metrics.runReqChan <- -1 if s.isLora(model) { // update loraInfo metrics to reflect that the request processing has been finished - s.lorasChan <- loraUsage{model, doneUsageState} + s.metrics.lorasChan <- loraUsage{model, doneUsageState} } if s.config.EnableKVCache && !isChatCompletion { @@ -529,16 +580,15 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r time.Sleep(time.Duration(ttft) * time.Millisecond) // report ttft in seconds - s.ttftChan <- (float64(ttft) / 1000) + s.metrics.ttftChan <- (float64(ttft) / 1000) for range usageData.CompletionTokens - 1 { perTokenLatency := s.getInterTokenLatency() time.Sleep(time.Duration(perTokenLatency) * time.Millisecond) // report tpot in seconds - s.tpotChan <- float64(perTokenLatency) / 1000 + s.metrics.tpotChan <- float64(perTokenLatency) / 1000 } - s.sendCompletionResponse(reqCtx.HTTPReqCtx, resp) s.responseSentCallback(modelName, reqCtx.IsChatCompletion, reqCtx.CompletionReq.GetRequestID()) @@ -575,3 +625,41 @@ func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse { return &modelsResp } + +func (s *VllmSimulator) enqueue(req *openaiserverapi.CompletionReqCtx) error { + s.queueLock.Lock() + defer s.queueLock.Unlock() + + if s.waitingQueue.Len() >= s.queueCapacity { + return errors.New("waiting requests queue is full") + } + s.waitingQueue.PushBack(req) + return nil +} + +// go though the queue and find the first request that can be executed, while taking into consideration the max lora limitation +func (s *VllmSimulator) dequeue() *openaiserverapi.CompletionReqCtx { + 
s.queueLock.Lock() + defer s.queueLock.Unlock() + + // Find first request for a loaded LoRA + for elem := s.waitingQueue.Front(); elem != nil; elem = elem.Next() { + req, ok := elem.Value.(*openaiserverapi.CompletionReqCtx) + if ok && req != nil && s.loraIsLoaded(req.CompletionReq.GetModel()) { + s.waitingQueue.Remove(elem) + s.incrementLora(req.CompletionReq.GetModel()) + return req + } + } + + // All the requests require a LoRA that is not loaded, check if we can load a LoRA + for elem := s.waitingQueue.Front(); elem != nil; elem = elem.Next() { + req, ok := elem.Value.(*openaiserverapi.CompletionReqCtx) + if ok && req != nil && s.loadLora(req.CompletionReq.GetModel()) { + s.waitingQueue.Remove(elem) + return req + } + } + + return nil +} diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index b9b276e6..5a5583b0 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -29,8 +29,6 @@ import ( "github.com/llm-d/llm-d-inference-sim/pkg/common" "github.com/llm-d/llm-d-inference-sim/pkg/dataset" - kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache" - "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/openai/openai-go/v3" @@ -90,50 +88,14 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m } s.config = config - for _, lora := range config.LoraModules { - s.loraAdaptors.Store(lora.Name, "") - } - - common.InitRandom(s.config.Seed) - - if err := s.createAndRegisterPrometheus(); err != nil { - return nil, err - } - - tokenizationConfig := tokenization.DefaultConfig() - if s.config.TokenizersCacheDir != "" { - tokenizationConfig.TokenizersCacheDir = s.config.TokenizersCacheDir - } - s.tokenizer, err = tokenization.NewCachedHFTokenizer(tokenizationConfig.HFTokenizerConfig) - if err != nil { - return nil, fmt.Errorf("failed to create tokenizer: %w", err) - } - - if s.config.EnableKVCache { - s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger, s.kvCacheUsageChan, s.tokenizer) - if err != nil { - return nil, err - } - - go s.kvcacheHelper.Run(ctx) - } - - err = s.initDataset(ctx) - if err != nil { - return nil, fmt.Errorf("dataset initialization error: %w", err) - } - // calculate number of tokens for user message, // must be activated after parseCommandParamsAndLoadConfig since it initializes the random engine userMsgTokens = int64(len(common.Tokenize(userMessage))) - // run request processing workers - for i := 1; i <= s.config.MaxNumSeqs; i++ { - go s.reqProcessingWorker(ctx, i) + if err := s.initializeSim(ctx); err != nil { + return nil, err } - s.startMetricsUpdaters(ctx) - listener := fasthttputil.NewInmemoryListener() // start the http server diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index 1bd2525d..ee02b4b1 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -104,14 +104,14 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ ttft := s.getWaitTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill) time.Sleep(time.Duration(ttft) * time.Millisecond) // report ttft in seconds - s.ttftChan <- (float64(ttft) / 1000) + s.metrics.ttftChan <- (float64(ttft) / 1000) for i, token := range genTokens { if i != 0 { interTokenLat := s.getInterTokenLatency() time.Sleep(time.Duration(interTokenLat) * time.Millisecond) // 
report tpot in seconds - s.tpotChan <- float64(interTokenLat) / 1000 + s.metrics.tpotChan <- float64(interTokenLat) / 1000 } var toolChunkInsert *openaiserverapi.ToolCall diff --git a/pkg/llm-d-inference-sim/worker.go b/pkg/llm-d-inference-sim/worker.go new file mode 100644 index 00000000..5171580a --- /dev/null +++ b/pkg/llm-d-inference-sim/worker.go @@ -0,0 +1,165 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package vllmsim implements the vLLM simulator. +package llmdinferencesim + +import ( + "context" + + "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" + "github.com/valyala/fasthttp" +) + +// worker runs simulators requests +type worker struct { + ctx context.Context + logger logr.Logger + // worker's id + id int + // a channel for requests + reqChan chan *openaiserverapi.CompletionReqCtx + // a channel to indicate that the worker finished processing a request + finishedChan chan *requestCompleted + // the request processor + processor requestProcessor +} + +func (w *worker) waitForRequests() { + for { + select { + case <-w.ctx.Done(): + w.logger.V(4).Info("worker done", "id", w.id) + return + case req := <-w.reqChan: + w.processor.processRequest(req) + w.finishedChan <- &requestCompleted{worker: w, model: req.CompletionReq.GetModel()} + } + } +} + +type requestProcessor interface { + processRequest(reqCtx *openaiserverapi.CompletionReqCtx) +} + +func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx) { + req := reqCtx.CompletionReq + model := req.GetModel() + displayModel := s.getDisplayedModelName(model) + + // increment running requests count + s.metrics.runReqChan <- 1 + + if s.isLora(model) { + // update loraInfo metric to reflect that + // the request has changed its status from waiting to running + s.metrics.lorasChan <- loraUsage{model, runningUsageState} + } + + if s.config.EnableKVCache && !reqCtx.IsChatCompletion { + // kv cache is currently supported for /completion API only + if err := s.kvcacheHelper.OnRequestStart(req); err != nil { + s.sendCompletionError(reqCtx.HTTPReqCtx, + openaiserverapi.NewCompletionError(err.Error(), fasthttp.StatusInternalServerError, nil), + false) + } + } + + var responseTokens []string + var finishReason string + var err error + var toolCalls []openaiserverapi.ToolCall + var completionTokens int + if reqCtx.IsChatCompletion && + req.GetToolChoice() != openaiserverapi.ToolChoiceNone && + req.GetTools() != nil { + toolCalls, completionTokens, err = + openaiserverapi.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config) + finishReason = dataset.ToolsFinishReason + } + if toolCalls == nil && err == nil { + // Either no tool calls were defined, or we randomly chose not to create tool calls, + // so we generate a response text. 
+ responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode) + completionTokens += len(responseTokens) + } + if err != nil { + prefix := "" + if reqCtx.IsChatCompletion { + prefix = "failed to create chat response" + } else { + prefix = "failed to create text response" + } + s.logger.Error(err, prefix) + reqCtx.HTTPReqCtx.Error(prefix+err.Error(), fasthttp.StatusBadRequest) + } else { + usageData := openaiserverapi.Usage{ + PromptTokens: req.GetNumberOfPromptTokens(), + CompletionTokens: completionTokens, + TotalTokens: req.GetNumberOfPromptTokens() + completionTokens, + } + if req.IsStream() { + var usageDataToSend *openaiserverapi.Usage + if req.IncludeUsage() { + usageDataToSend = &usageData + } + s.sendStreamingResponse( + &streamingContext{ + ctx: reqCtx.HTTPReqCtx, + isChatCompletion: reqCtx.IsChatCompletion, + model: displayModel, + doRemotePrefill: req.IsDoRemotePrefill(), + nPromptTokens: usageData.PromptTokens, + nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(), + }, + responseTokens, toolCalls, finishReason, usageDataToSend, + ) + } else { + if req.IsDoRemoteDecode() { + // in case this is prefill pod processing, return special finish reason + finishReason = dataset.RemoteDecodeFinishReason + } + s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData) + } + + select { + case s.metrics.requestSuccessChan <- requestSuccessEvent{ + promptTokens: usageData.PromptTokens, + generationTokens: usageData.CompletionTokens, + maxTokens: reqCtx.CompletionReq.GetMaxCompletionTokens(), + finishReason: finishReason, + }: + default: + s.logger.V(1).Info("requestSuccessChan full, dropping success event") + } + } + + s.logger.V(4).Info("Finished processing request", "id", req.GetRequestID()) + + reqCtx.Wg.Done() +} + +// getFreeWorker returns a free worker or nil if none are available (non-blocking) +func (s *VllmSimulator) getFreeWorker() *worker { + select { + case w := <-s.freeWorkers: + return w + default: + return nil + } +} diff --git a/pkg/llm-d-inference-sim/worker_test.go b/pkg/llm-d-inference-sim/worker_test.go new file mode 100644 index 00000000..8d2c612d --- /dev/null +++ b/pkg/llm-d-inference-sim/worker_test.go @@ -0,0 +1,586 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "context" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" +) + +const modelName = "testmodel" + +var _ = Describe("Simulator requests scheduling", Ordered, func() { + Context("Requests for already loaded loras should be handled first", func() { + DescribeTable("Should process in correct order simultaneous requests to two loras", func(maxNumSeq string) { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeEcho, + "--time-to-first-token", "500", "--max-num-seqs", maxNumSeq, + "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + client, err := startServerWithArgs(ctx, common.ModeEcho, args, nil) + Expect(err).NotTo(HaveOccurred()) + openaiclient := openai.NewClient(option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + numberOfRequests := 4 + orderOfResponses := make([]int, 0) + var wg sync.WaitGroup + wg.Add(numberOfRequests) + var mux sync.RWMutex + + // Send simultaneously half of the requests to lora1 and the second half to lora2 + for reqNum := range numberOfRequests { + params := paramsLora2 + if reqNum%2 == 0 { + params = paramsLora1 + } + go sendReq(ctx, openaiclient, &wg, 0, params, reqNum, &mux, &orderOfResponses) + } + wg.Wait() + + // Check the order in which the requests are handled: + // if the first handled request is even, all the first half of the requests should + // be even (because they all use the same lora that is already loaded). + firstReqIsEven := orderOfResponses[0]%2 == 0 + for i, reqNum := range orderOfResponses { + if i < numberOfRequests/2 { + // nolint + Expect(reqNum%2 == 0).To(Equal(firstReqIsEven)) + } else { + // nolint + Expect(reqNum%2 == 0).NotTo(Equal(firstReqIsEven)) + } + } + }, + Entry("5 workers", "5"), + Entry("1 worker", "1"), + ) + + DescribeTable("Should process in correct order delayed requests to two loras", + func(maxNumSeq string, maxLoras string, checkOrder func([]int)) { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeEcho, + "--time-to-first-token", "1000", + "--max-num-seqs", maxNumSeq, "--max-loras", maxLoras, + "--lora-modules", "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + client, err := startServerWithArgs(ctx, common.ModeEcho, args, nil) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient(option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + numberOfRequests := 8 + orderOfResponses := make([]int, 0) + var wg sync.WaitGroup + wg.Add(numberOfRequests) + var mux sync.RWMutex + + // Send three requests to lora1, after 100 milliseconds four requests to lora2, + // and after 400 milliseconds a request to lora1 again. 
+ for reqNum := range 3 { + go sendReq(ctx, openaiclient, &wg, 0, paramsLora1, reqNum, &mux, &orderOfResponses) + } + for reqNum := 4; reqNum < 8; reqNum++ { + go sendReq(ctx, openaiclient, &wg, 100, paramsLora2, reqNum, &mux, &orderOfResponses) + } + go sendReq(ctx, openaiclient, &wg, 500, paramsLora1, 3, &mux, &orderOfResponses) + + wg.Wait() + + // Check the order in which the requests are handled + checkOrder(orderOfResponses) + }, + Entry("5 workers, max loras 1", "5", "1", checkOrder), + Entry("1 worker, max loras 5", "1", "5", checkOrder), + Entry("2 workers, max loras 1", "2", "1", checkOrderMaxLora1Workers2), + Entry("5 workers, max loras 5", "5", "5", checkOrderMaxLora5Workers5), + ) + + It("Should keep the order of requests with one worker", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeEcho, + "--time-to-first-token", "500", + "--max-num-seqs", "1", "--max-loras", "1", + "--lora-modules", + "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", + "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + client, err := startServerWithArgs(ctx, common.ModeEcho, args, nil) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient(option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + numberOfRequests := 9 + orderOfResponses := make([]int, 0) + var wg sync.WaitGroup + wg.Add(numberOfRequests) + var mux sync.RWMutex + + // The order of the requests is: + // 0-lora1 1-lora1 2-lora2 3-lora3 4-lora4 5-lora1 6-lora2 7-lora3 8-lora4 + go sendReq(ctx, openaiclient, &wg, 0, paramsLora1, 0, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 50, paramsLora1, 1, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 100, paramsLora2, 2, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 200, paramsLora3, 3, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 300, paramsLora4, 4, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 400, paramsLora1, 5, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 500, paramsLora2, 6, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 600, paramsLora3, 7, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 700, paramsLora4, 8, &mux, &orderOfResponses) + wg.Wait() + + // Check the order in which the requests are handled + checkOrderMaxLora1Workers1(orderOfResponses) + }) + + It("Should keep the order of requests with two workers", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeEcho, + "--time-to-first-token", "500", + "--max-num-seqs", "2", "--max-loras", "1", + "--lora-modules", + "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", + "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + client, err := startServerWithArgs(ctx, common.ModeEcho, args, nil) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient(option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + numberOfRequests := 8 + orderOfResponses := make([]int, 0) + var wg sync.WaitGroup + wg.Add(numberOfRequests) + var mux sync.RWMutex + + // The order of the requests is: + // 0-lora1 1-lora1 2-lora2 3-lora3 4-lora4 5-lora2 6-lora3 7-lora4 + go sendReq(ctx, openaiclient, &wg, 0, paramsLora1, 0, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 
0, paramsLora1, 1, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 100, paramsLora2, 2, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 200, paramsLora3, 3, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 300, paramsLora4, 4, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 400, paramsLora2, 5, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 500, paramsLora3, 6, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 600, paramsLora4, 7, &mux, &orderOfResponses) + wg.Wait() + + // Check the order in which the requests are handled + checkOrderWorkers2(orderOfResponses) + }) + + DescribeTable("Should keep the order of requests with multiple workers and loras", + func(maxNumSeq string, maxLoras string, checkOrder func([]int)) { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", common.ModeEcho, + "--time-to-first-token", "1000", + "--max-num-seqs", maxNumSeq, "--max-loras", maxLoras, + "--lora-modules", + "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", + "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", + "{\"name\":\"lora5\",\"path\":\"/path/to/lora5\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}"} + + client, err := startServerWithArgs(ctx, common.ModeEcho, args, nil) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient(option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + numberOfRequests := 11 + orderOfResponses := make([]int, 0) + var wg sync.WaitGroup + wg.Add(numberOfRequests) + var mux sync.RWMutex + + // The order of the requests is: + // 0-lora1 1-lora1 2-lora2 3-lora3 4-lora4 5-lora5 + // 6-lora1 7-lora2 8-lora3 9-lora4 10-lora5 + go sendReq(ctx, openaiclient, &wg, 0, paramsLora1, 0, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 100, paramsLora1, 1, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 200, paramsLora2, 2, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 300, paramsLora3, 3, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 400, paramsLora4, 4, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 500, paramsLora5, 5, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 600, paramsLora1, 6, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 700, paramsLora2, 7, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 800, paramsLora3, 8, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 900, paramsLora4, 9, &mux, &orderOfResponses) + go sendReq(ctx, openaiclient, &wg, 1000, paramsLora5, 10, &mux, &orderOfResponses) + wg.Wait() + + // Check the order in which the requests are handled + checkOrder(orderOfResponses) + }, + Entry("4 workers, max loras 3", "4", "3", checkOrderMaxLora3), + Entry("5 workers, max loras 3", "5", "3", checkOrderMaxLora3), + Entry("5 workers, max loras 5", "5", "5", checkOrderMaxLora5), + ) + + }) + + Context("Stress", func() { + It("Should work correctly with many simultaneous requests", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom, + "--time-to-first-token", "3000", "--max-num-seqs", "12", "--max-loras", "2", + "--lora-modules", + "{\"name\":\"lora0\",\"path\":\"/path/to/lora0\"}", + "{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}", + "{\"name\":\"lora2\",\"path\":\"/path/to/lora2\"}", + "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", + 
"{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", + } + + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + // Run 1000 requests for 5 loras simultaneously + numberOfRequests := 1000 + for i := range numberOfRequests { + go func() { + defer GinkgoRecover() + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + Model: fmt.Sprintf("lora%d", i%5), + } + _, err := openaiclient.Chat.Completions.New(ctx, params) + Expect(err).NotTo(HaveOccurred()) + }() + } + + time.Sleep(2000 * time.Millisecond) + metricsResp, err := client.Get(metricsUrl) + Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + + data, err := io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics := string(data) + + // max-num-seqs is 12, so number of running requests should be 12 + // and the number of waiting requests 1000-12=988 + Expect(metrics).To(ContainSubstring("vllm:num_requests_running{model_name=\"testmodel\"} 12")) + Expect(metrics).To(ContainSubstring("vllm:num_requests_waiting{model_name=\"testmodel\"} 988")) + + // max-loras is 2, so the last lora metric should be: + // running: two loras (doesn't matter which two) + // waiting: all the five loras + // (there can be more than one metric with the same timestamp, therefore we check all of them) + lastLoraMetrics, err := getLastLoraMetrics(strings.Split(string(data), "\n")) + Expect(err).NotTo(HaveOccurred()) + + allLoras := []string{"lora1", "lora2", "lora3", "lora4", "lora0"} + Expect( + isLoraMetricPresent(lastLoraMetrics, []string{"lora1", "lora2"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora1", "lora3"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora1", "lora4"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora1", "lora0"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora3", "lora2"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora4", "lora2"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora0", "lora2"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora3", "lora4"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora3", "lora0"}, allLoras) || + isLoraMetricPresent(lastLoraMetrics, []string{"lora4", "lora0"}, allLoras)). 
+				To(BeTrue())
+		})
+
+		It("Should work correctly with many simultaneous requests with many workers", func() {
+			runningMetric := "vllm:num_requests_running{model_name=\"testmodel\"}"
+			waitingMetric := "vllm:num_requests_waiting{model_name=\"testmodel\"}"
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", modelName, "--mode", common.ModeRandom,
+				"--time-to-first-token", "2000", "--time-to-first-token-std-dev", "600",
+				"--max-num-seqs", "1000", "--max-loras", "2", "--max-waiting-queue-length", "1500",
+				"--lora-modules",
+				"{\"name\":\"lora0\",\"path\":\"/path/to/lora0\"}",
+				"{\"name\":\"lora1\",\"path\":\"/path/to/lora1\"}",
+			}
+
+			client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
+			Expect(err).NotTo(HaveOccurred())
+
+			openaiclient := openai.NewClient(
+				option.WithBaseURL(baseURL),
+				option.WithHTTPClient(client))
+
+			// Run 2000 requests for 2 loras simultaneously
+			numberOfRequests := 2000
+			for i := range numberOfRequests {
+				go func() {
+					defer GinkgoRecover()
+					params := openai.ChatCompletionNewParams{
+						Messages: []openai.ChatCompletionMessageParamUnion{
+							openai.UserMessage(userMessage),
+						},
+						Model: fmt.Sprintf("lora%d", i%2),
+					}
+					_, err := openaiclient.Chat.Completions.New(ctx, params)
+					Expect(err).NotTo(HaveOccurred())
+				}()
+			}
+
+			time.Sleep(400 * time.Millisecond)
+			metricsResp, err := client.Get(metricsUrl)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(metricsResp.StatusCode).To(Equal(http.StatusOK))
+
+			data, err := io.ReadAll(metricsResp.Body)
+			Expect(err).NotTo(HaveOccurred())
+			metrics := strings.Split(string(data), "\n")
+
+			// max-num-seqs is 1000, so the number of running requests should be 1000
+			// and the number of waiting requests 2000-1000=1000
+			runningStr := findMetric(metrics, runningMetric)
+			Expect(runningStr).NotTo(Equal(""))
+			running, err := strconv.Atoi(runningStr)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(running).To(Equal(1000))
+			waitingStr := findMetric(metrics, waitingMetric)
+			waiting, err := strconv.Atoi(waitingStr)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(waiting).To(Equal(1000))
+
+			time.Sleep(1500 * time.Millisecond)
+
+			// After about 2 secs (the mean ttft), send 500 more requests
+			numberOfRequests = 500
+			for i := range numberOfRequests {
+				go func() {
+					defer GinkgoRecover()
+					params := openai.ChatCompletionNewParams{
+						Messages: []openai.ChatCompletionMessageParamUnion{
+							openai.UserMessage(userMessage),
+						},
+						Model: fmt.Sprintf("lora%d", i%2),
+					}
+					_, err := openaiclient.Chat.Completions.New(ctx, params)
+					Expect(err).NotTo(HaveOccurred())
+				}()
+			}
+			time.Sleep(400 * time.Millisecond)
+			metricsResp, err = client.Get(metricsUrl)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(metricsResp.StatusCode).To(Equal(http.StatusOK))
+
+			data, err = io.ReadAll(metricsResp.Body)
+			Expect(err).NotTo(HaveOccurred())
+			metrics = strings.Split(string(data), "\n")
+
+			// We sent 2500 requests, after about 2.5 seconds
+			// number of running requests should be 1000
+			// and the number of waiting requests should be less than 1000
+			runningStr = findMetric(metrics, runningMetric)
+			Expect(runningStr).NotTo(Equal(""))
+			running, err = strconv.Atoi(runningStr)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(running).To(Equal(1000))
+			waitingStr = findMetric(metrics, waitingMetric)
+			waiting, err = strconv.Atoi(waitingStr)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(waiting).To(BeNumerically("<", 1000))
+
+			// Wait another second
+			time.Sleep(1000 * time.Millisecond)
+			metricsResp, err = client.Get(metricsUrl)
+
Expect(err).NotTo(HaveOccurred()) + Expect(metricsResp.StatusCode).To(Equal(http.StatusOK)) + data, err = io.ReadAll(metricsResp.Body) + Expect(err).NotTo(HaveOccurred()) + metrics = strings.Split(string(data), "\n") + + // number of running requests should be 1000 + // and the number of waiting requests should be less than 1000 + runningStr = findMetric(metrics, runningMetric) + Expect(runningStr).NotTo(Equal("")) + running, err = strconv.Atoi(runningStr) + Expect(err).NotTo(HaveOccurred()) + Expect(running).To(Equal(1000)) + waitingStr = findMetric(metrics, waitingMetric) + waiting, err = strconv.Atoi(waitingStr) + Expect(err).NotTo(HaveOccurred()) + Expect(waiting).To(BeNumerically("<", 1000)) + }) + }) +}) + +func sendReq(ctx context.Context, openaiclient openai.Client, wg *sync.WaitGroup, delay int, + params openai.ChatCompletionNewParams, reqNum int, mux *sync.RWMutex, orderOfResponses *[]int) { + defer GinkgoRecover() + defer wg.Done() + time.Sleep(time.Duration(delay) * time.Millisecond) + _, err := openaiclient.Chat.Completions.New(ctx, params) + Expect(err).NotTo(HaveOccurred()) + mux.Lock() + *orderOfResponses = append(*orderOfResponses, reqNum) + mux.Unlock() +} + +// Check the order of the delayed requests with max-loras=1 and two workers +// Three requests to lora1 (req numbers 0-2) +// after a delay four requests to lora2 (req numbers 4-7), +// after a delay one more request to lora1 (req number 3). +// All the requests to lora1 should be handled before the requests to lora2. +// The first two requests have to be 0-2, the next two should be one of the requests +// from the first batch (0-2) and the last request to lora1 (req number 3), the +// next four should be requests to lora2 (4-7) in no particular order. +func checkOrderMaxLora1Workers2(orderOfResponses []int) { + Expect(orderOfResponses).To(HaveLen(8)) + for i, reqNum := range orderOfResponses { + switch { + case i < 2: + Expect(reqNum).To(BeNumerically("<", 3)) + case i < 4: + Expect(reqNum).To(BeNumerically("<", 4)) + default: + Expect(reqNum >= 4 && reqNum < 8).To(BeTrue()) + } + } +} + +// Check the order of the delayed requests with max-loras=5 and five workers +// Three requests to lora1 (req numbers 0-2) +// after a delay four requests to lora2 (req numbers 4-7), +// after a delay one more request to lora1 (req number 3). +// The requests should be handled in the order they are sent. +// The exact order of first three requests to lora1 and the four +// requests to lora2 is not important. +// The first three should be 0-2, the next two should be 4-7, +// the rest can be in any order. +func checkOrderMaxLora5Workers5(orderOfResponses []int) { + for i, reqNum := range orderOfResponses { + switch { + case i < 3: + Expect(reqNum).To(BeNumerically("<", 3)) + case i < 5: + Expect(reqNum >= 4 && reqNum <= 7).To(BeTrue()) + default: + Expect(reqNum).To(BeNumerically(">=", 3)) + } + } +} + +// Check the order of the delayed requests with max-loras=5 and one worker +// Three requests to lora1 (req numbers 0-2) +// after a delay four requests to lora2 (req numbers 4-7), +// after a delay one more request to lora1 (req number 3). +// The requests should be handled in the order they are sent. +// The exact order of first three requests to lora1 and the four +// requests to lora2 is not important. +// The first three should be 0-2, the next one should be 3, +// the rest 4-7. 
+func checkOrder(orderOfResponses []int) { + for i, reqNum := range orderOfResponses { + switch { + case i < 3: + Expect(reqNum).To(BeNumerically("<", 3)) + case i == 3: + Expect(reqNum).To(Equal(3)) + default: + Expect(reqNum).To(BeNumerically(">", 3)) + } + } +} + +// Check the order of requests sent in specific order with one worker +// The requests are sent with delays to make sure they enter the queue +// in the order they are sent. +// The order of the requests is: +// 0-lora1 1-lora1 2-lora2 3-lora3 4-lora4 5-lora1 6-lora2 7-lora3 8-lora4 +// The expected order of processing: +// 015263748 +func checkOrderMaxLora1Workers1(orderOfResponses []int) { + expected := []int{0, 1, 5, 2, 6, 3, 7, 4, 8} + Expect(orderOfResponses).To(Equal(expected)) +} + +// Check the order of requests sent in specific order with two workers +// The requests are sent with delays to make sure they enter the queue +// in the order they are sent. +// The order of the requests is: +// 0-lora1 1-lora1 2-lora2 3-lora3 4-lora4 5-lora2 6-lora3 7-lora4 +// The expected order of processing: +// {01}{25}{36}{47} - the order inside the brackets doesn't matter +func checkOrderWorkers2(orderOfResponses []int) { + expected1 := []int{0, 1, 2, 5, 3, 6, 4, 7} + expected2 := []int{1, 0, 5, 2, 6, 3, 7, 4} + Expect(orderOfResponses).To(HaveLen(8)) + for i, reqNum := range orderOfResponses { + Expect(reqNum).To(Or(Equal(expected1[i]), Equal(expected2[i]))) + } +} + +// Check the order of requests sent in specific order with max loras = 3 +// The requests are sent with delays to make sure they enter the queue +// in the order they are sent. +// The order of the requests is: +// 0-lora1 1-lora1 2-lora2 3-lora3 4-lora4 5-lora5 +// 6-lora1 7-lora2 8-lora3 9-lora4 10-lora5 +// The expected order of processing: +// 0, 1, 2, 3, 6, 7, 8, {4, 9}, {5, 10} - the order inside the brackets doesn't matter +func checkOrderMaxLora3(orderOfResponses []int) { + expected1 := []int{0, 1, 2, 3, 6, 7, 8, 4, 9, 5, 10} + expected2 := []int{0, 1, 2, 3, 6, 7, 8, 4, 9, 10, 5} + expected3 := []int{0, 1, 2, 3, 6, 7, 8, 9, 4, 5, 10} + expected4 := []int{0, 1, 2, 3, 6, 7, 8, 9, 4, 10, 5} + Expect(orderOfResponses).To(HaveLen(11)) + for i, reqNum := range orderOfResponses { + Expect(reqNum).To(Or(Equal(expected1[i]), Equal(expected2[i]), + Equal(expected3[i]), Equal(expected4[i]))) + } +} + +// Check the order of requests sent in specific order with max loras = 5 +// The requests are sent with delays to make sure they enter the queue +// in the order they are sent. +// The order of the requests is: +// 0-lora1 1-lora1 2-lora2 3-lora3 4-lora4 5-lora5 +// 6-lora1 7-lora2 8-lora3 9-lora4 10-lora5 +// The expected order of processing: +// 0, 1, 2, 3, 4, 6, 7, 8, 9, 5, 10 +func checkOrderMaxLora5(orderOfResponses []int) { + expected := []int{0, 1, 2, 3, 4, 6, 7, 8, 9, 5, 10} + Expect(orderOfResponses).To(Equal(expected)) +}