diff --git a/container/Dockerfile.sglang-deepep b/container/Dockerfile.sglang-deepep
index 441e3f56c1..c95891209c 100644
--- a/container/Dockerfile.sglang-deepep
+++ b/container/Dockerfile.sglang-deepep
@@ -35,22 +35,62 @@ ARG ARCH_ALT=x86_64
 WORKDIR /sgl-workspace
 
+# Install UCX dependencies
+RUN apt-get update -y && \
+    apt-get install -y --no-install-recommends \
+    --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev \
+    libnuma-dev librdmacm-dev ibverbs-providers \
+    autoconf libtool
+
+# Build UCX from source
+ARG NIXL_UCX_REF=v1.19.x
+RUN rm -rf /opt/hpcx/ucx && \
+    rm -rf /usr/local/ucx && \
+    cd /usr/local/src && \
+    git clone https://github.com/openucx/ucx.git && \
+    cd ucx && \
+    git checkout $NIXL_UCX_REF && \
+    ./autogen.sh && ./configure \
+    --prefix=/usr/local/ucx \
+    --enable-shared \
+    --disable-static \
+    --disable-doxygen-doc \
+    --enable-optimizations \
+    --enable-cma \
+    --enable-devel-headers \
+    --with-cuda=/usr/local/cuda \
+    --with-verbs \
+    --with-efa \
+    --with-dm \
+    --with-gdrcopy=/usr/local \
+    --enable-mt && \
+    make -j && \
+    make -j install-strip && \
+    ldconfig
+
+ENV LD_LIBRARY_PATH=/usr/lib:/usr/local/ucx/lib:$LD_LIBRARY_PATH
+
 # Pinning to NIXL 0.2.1 right now
 # TODO: investigate pip install failure with 0.3.0 release
 ARG NIXL_COMMIT="5e4c179ee850d482a83cb2a211e0947e46281060"
-RUN git clone https://github.com/ai-dynamo/nixl.git && cd nixl && git checkout ${NIXL_COMMIT} &&pip install --break-system-packages . --config-settings=setup-args="-Ducx_path=/opt/hpcx/ucx"
+RUN git clone https://github.com/ai-dynamo/nixl.git && cd nixl && git checkout ${NIXL_COMMIT} && pip install --break-system-packages . --config-settings=setup-args="-Ducx_path=/usr/local/ucx"
 
 WORKDIR /sgl-workspace
 RUN pip uninstall --break-system-packages -y sglang
 RUN rm -rf sglang
 
-# 0.4.7
-RUN pip install --break-system-packages "sglang==0.4.7"
+# 0.4.8 has a bug with CUDA graphs and decode worker
+# https://github.com/sgl-project/sglang/issues/7511
+RUN pip install --break-system-packages "sglang==0.4.7.post1"
+
+# Allow forceful shutdown of inflight requests
+ENV SGL_FORCE_SHUTDOWN=1
 
 WORKDIR /sgl-workspace
 # https://github.com/ai-dynamo/dynamo/pull/1510
 ARG DYNAMO_COMMIT="382e3aedc421b3b3abc338062b332b54b5aa8529"
-RUN git clone https://github.com/ai-dynamo/dynamo.git && cd dynamo && git checkout ${DYNAMO_COMMIT}
+ARG DYNAMO_BRANCH="ishan/cmpl-token-id"
+RUN git clone https://github.com/ai-dynamo/dynamo.git && cd dynamo && git checkout ${DYNAMO_BRANCH}
 
 # install dynamo in editable mode
 WORKDIR /sgl-workspace/dynamo
diff --git a/examples/sglang/README.md b/examples/sglang/README.md
index 96b7c5f532..9e9c36905b 100644
--- a/examples/sglang/README.md
+++ b/examples/sglang/README.md
@@ -106,12 +106,12 @@ Dynamo supports SGLang's implementation of wide expert parallelism and large sca
 Steps to run:
 
-1. Build the SGLang DeepEP container
+1. Build the SGLang DeepEP container.
 
 ```bash
-git clone https://github.com/sgl-project/sglang.git
+git clone -b v0.4.8 https://github.com/sgl-project/sglang.git
 cd sglang/docker
-docker build -f Dockerfile.deepep -t deepep .
+docker build -f Dockerfile -t deepep .
 ```
 
 You will now have a `deepep:latest` image
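A quick way to confirm which SGLang build ended up inside an image is to print the package version from within it. This is a hedged convenience check, not part of the change: the image tags are whatever you used above (`deepep:latest` for the DeepEP build; substitute your own tag for the image built from `Dockerfile.sglang-deepep`), and it assumes `python3` and SGLang are on the image's default path.

```python
import subprocess

# Expect 0.4.8 for the DeepEP image and 0.4.7.post1 for the dynamo sglang-deepep image.
for image in ["deepep:latest"]:  # add your sglang-deepep tag here
    result = subprocess.run(
        ["docker", "run", "--rm", image, "python3", "-c",
         "import sglang; print(sglang.__version__)"],
        capture_output=True, text=True, check=True,
    )
    print(f"{image}: {result.stdout.strip()}")
```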
diff --git a/examples/sglang/components/decode_worker.py b/examples/sglang/components/decode_worker.py
index 066c9108db..150d0fbed5 100644
--- a/examples/sglang/components/decode_worker.py
+++ b/examples/sglang/components/decode_worker.py
@@ -45,7 +45,9 @@ def __init__(self):
     @endpoint()
     async def generate(self, req: DisaggPreprocessedRequest):
         g = await self.engine.async_generate(
-            input_ids=req.request.token_ids,
+            input_ids=req.request.token_ids
+            if req.request.batch_token_ids is None
+            else req.request.batch_token_ids,
             sampling_params=req.sampling_params,
             stream=True,
             bootstrap_host=req.bootstrap_host,
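For orientation, this is roughly what the decode worker receives in batch mode once the prefill worker (next hunk) has fanned out the bootstrap metadata: `batch_token_ids` carries one token list per prompt, and the bootstrap fields become parallel lists of the same length. All values below are made up for illustration.

```python
disagg_request = {
    "request": {
        "token_ids": [],  # unused when batch_token_ids is populated
        "batch_token_ids": [[101, 2023, 102], [101, 7592, 2088, 102]],
        # stop_conditions, sampling_options, eos_token_ids, etc. as for single requests
    },
    "sampling_params": {"max_new_tokens": 64, "temperature": 0.7},
    # one entry per batch item; rooms are unique, host/port usually repeat
    "bootstrap_host": ["prefill-0", "prefill-0"],
    "bootstrap_port": [8998, 8998],
    "bootstrap_room": [4242424242424242, 9191919191919191],
}
```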
diff --git a/examples/sglang/components/worker.py b/examples/sglang/components/worker.py
index c737f8abf2..1e4c868064 100644
--- a/examples/sglang/components/worker.py
+++ b/examples/sglang/components/worker.py
@@ -28,6 +28,7 @@
 import logging
 import random
 import socket
+from typing import Dict, Union
 
 import sglang as sgl
 from components.decode_worker import SGLangDecodeWorker
@@ -112,63 +113,123 @@ def _build_sampling_params(self, request: PreprocessedRequest) -> dict:
             sampling_params["ignore_eos"] = request.stop_conditions.ignore_eos
         return sampling_params
 
+    def _get_request_batch_size(self, request: PreprocessedRequest):
+        """Get batch size from request, returns None for single requests"""
+        if request.batch_token_ids is not None:
+            return len(request.batch_token_ids)
+        return None
+
+    def _is_batch_request(self, request: PreprocessedRequest):
+        """Check if request is in batch mode"""
+        return request.batch_token_ids is not None
+
     @endpoint()
     async def generate(self, request: PreprocessedRequest):
+        # Check if we're in batch mode at the start
+        is_batch = self._is_batch_request(request)
+        batch_size = self._get_request_batch_size(request)
+
         # TODO: maintain a mapping from SGLang's Ouput struct to LLMEngineOuput
         sampling_params = self._build_sampling_params(request)
         if self.engine_args.disaggregation_mode != "null":
-            bootstrap_room = self._generate_bootstrap_room()
+            if is_batch:
+                bootstrap_room = [
+                    self._generate_bootstrap_room() for _ in range(batch_size)
+                ]
+                bootstrap_host = [self.bootstrap_host] * batch_size
+                bootstrap_port = [self.bootstrap_port] * batch_size
+            else:
+                bootstrap_host = self.bootstrap_host
+                bootstrap_port = self.bootstrap_port
+                bootstrap_room = self._generate_bootstrap_room()
 
             # decode worker request
             disagg_request = DisaggPreprocessedRequest(
                 request=request,
                 sampling_params=sampling_params,
-                bootstrap_host=self.bootstrap_host,
-                bootstrap_port=self.bootstrap_port,
+                bootstrap_host=bootstrap_host,
+                bootstrap_port=bootstrap_port,
                 bootstrap_room=bootstrap_room,
             )
 
             # prefill response is not used
             prefill = await self.engine.async_generate(
-                input_ids=request.token_ids,
+                input_ids=request.token_ids
+                if not is_batch
+                else request.batch_token_ids,
                 sampling_params=sampling_params,
                 stream=True,
-                bootstrap_host=self.bootstrap_host,
-                bootstrap_port=self.bootstrap_port,
+                bootstrap_host=bootstrap_host,
+                bootstrap_port=bootstrap_port,
                 bootstrap_room=bootstrap_room,
             )
             prefill_task = asyncio.create_task(self._prefill_generator(prefill))
 
             decode = await self.decode_client.generate(disagg_request.model_dump_json())
 
-            async for out in self._process_stream(decode, unpack=True):
+            async for out in self._process_stream(
+                decode, unpack=True, is_batch=is_batch
+            ):
                 yield out
             await prefill_task
         else:
             g = await self.engine.async_generate(
-                input_ids=request.token_ids,
+                input_ids=request.token_ids
+                if not is_batch
+                else request.batch_token_ids,
                 sampling_params=sampling_params,
                 stream=True,
             )
-            async for out in self._process_stream(g, unpack=False):
+            async for out in self._process_stream(g, unpack=False, is_batch=is_batch):
                 yield out
 
-    async def _process_stream(self, stream_source, unpack: bool):
-        num_output_tokens_so_far = 0
+    async def _process_stream(self, stream_source, unpack: bool, is_batch: bool):
+        # Initialize based on batch mode
+        num_output_tokens_so_far: Union[Dict[int, int], int]
+        if is_batch:
+            num_output_tokens_so_far = {}
+        else:
+            num_output_tokens_so_far = 0
+
         async for res in stream_source:
            data = res.data() if unpack else res
            finish_reason = data["meta_info"]["finish_reason"]
-            if finish_reason:
-                # Don't forward the stop token
-                out = {"token_ids": [], "finish_reason": finish_reason["type"]}
+
+            if is_batch:
+                # Handle batch response
+                assert isinstance(num_output_tokens_so_far, dict)
+                index = data.get("index", 0)
+                if index not in num_output_tokens_so_far:
+                    num_output_tokens_so_far[index] = 0
+
+                if finish_reason:
+                    out = {
+                        "token_ids": [],
+                        "finish_reason": finish_reason["type"],
+                        "index": index,
+                    }
+                else:
+                    next_total_toks = len(data["output_ids"])
+                    new_tokens = data["output_ids"][num_output_tokens_so_far[index] :]
+                    out = {
+                        "token_ids": new_tokens,
+                        "index": index,
+                    }
+                    num_output_tokens_so_far[index] = next_total_toks
            else:
-                next_total_toks = len(data["output_ids"])
-                out = {"token_ids": data["output_ids"][num_output_tokens_so_far:]}
+                # Handle single response
+                assert isinstance(num_output_tokens_so_far, int)
+                if finish_reason:
+                    out = {"token_ids": [], "finish_reason": finish_reason["type"]}
+                else:
+                    next_total_toks = len(data["output_ids"])
+                    out = {"token_ids": data["output_ids"][num_output_tokens_so_far:]}
+                    num_output_tokens_so_far = next_total_toks
+
            yield out
-            num_output_tokens_so_far = next_total_toks
 
     def _generate_bootstrap_room(self):
         return random.randint(0, 2**63 - 1)
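The batch branch of `_process_stream` forwards, per `index`, only the tokens appended since the previous chunk. A self-contained sketch of that bookkeeping with fabricated engine chunks (real SGLang output carries more fields):

```python
def deltas(chunks):
    """Convert cumulative per-index output_ids into incremental token deltas."""
    emitted = {}  # index -> number of tokens already forwarded
    for chunk in chunks:
        idx = chunk.get("index", 0)
        finish = chunk["meta_info"]["finish_reason"]
        if finish:
            yield {"token_ids": [], "finish_reason": finish["type"], "index": idx}
            continue
        prev = emitted.get(idx, 0)
        yield {"token_ids": chunk["output_ids"][prev:], "index": idx}
        emitted[idx] = len(chunk["output_ids"])


stream = [
    {"index": 0, "output_ids": [11], "meta_info": {"finish_reason": None}},
    {"index": 1, "output_ids": [21], "meta_info": {"finish_reason": None}},
    {"index": 0, "output_ids": [11, 12], "meta_info": {"finish_reason": None}},
    {"index": 0, "output_ids": [11, 12], "meta_info": {"finish_reason": {"type": "stop"}}},
    {"index": 1, "output_ids": [21, 22], "meta_info": {"finish_reason": {"type": "length"}}},
]

for delta in deltas(stream):
    print(delta)
# {'token_ids': [11], 'index': 0}
# {'token_ids': [21], 'index': 1}
# {'token_ids': [12], 'index': 0}
# {'token_ids': [], 'finish_reason': 'stop', 'index': 0}
# {'token_ids': [], 'finish_reason': 'length', 'index': 1}
```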
diff --git a/examples/sglang/utils/protocol.py b/examples/sglang/utils/protocol.py
index 6a38eaf52a..15ca7b20b2 100644
--- a/examples/sglang/utils/protocol.py
+++ b/examples/sglang/utils/protocol.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from pydantic import BaseModel, Field
@@ -47,6 +47,7 @@ class SamplingOptions(BaseModel):
 
 class PreprocessedRequest(BaseModel):
     token_ids: List[TokenIdType]
+    batch_token_ids: Optional[List[List[TokenIdType]]] = None
     stop_conditions: StopConditions
     sampling_options: SamplingOptions
     eos_token_ids: List[TokenIdType] = Field(default_factory=list)
@@ -57,7 +58,7 @@ class PreprocessedRequest(BaseModel):
 class DisaggPreprocessedRequest(BaseModel):
     request: PreprocessedRequest
     sampling_params: dict
-    bootstrap_host: str
-    bootstrap_port: int
-    bootstrap_room: int
+    bootstrap_host: Union[str, List[str]]
+    bootstrap_port: Union[int, List[int]]
+    bootstrap_room: Union[int, List[int]]
     data_parallel_rank: Optional[int] = None
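The widened `Union` annotations mean one `DisaggPreprocessedRequest` schema covers both shapes. A trimmed-down, hypothetical model (only the bootstrap fields) shows the two accepted forms; it is not the project's actual class, just a validation sketch:

```python
from typing import List, Union

from pydantic import BaseModel


class BootstrapInfo(BaseModel):  # hypothetical subset of DisaggPreprocessedRequest
    bootstrap_host: Union[str, List[str]]
    bootstrap_port: Union[int, List[int]]
    bootstrap_room: Union[int, List[int]]


# Single request: plain scalars, as before.
single = BootstrapInfo(bootstrap_host="prefill-0", bootstrap_port=8998, bootstrap_room=42)

# Batch of two: parallel lists, one entry per item in batch_token_ids.
batch = BootstrapInfo(
    bootstrap_host=["prefill-0", "prefill-0"],
    bootstrap_port=[8998, 8998],
    bootstrap_room=[42, 43],
)

print(single.model_dump())
print(batch.model_dump())
```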
diff --git a/launch/dynamo-run/src/subprocess/sglang_inc.py b/launch/dynamo-run/src/subprocess/sglang_inc.py
index e7cf0b0b86..469b8a35db 100644
--- a/launch/dynamo-run/src/subprocess/sglang_inc.py
+++ b/launch/dynamo-run/src/subprocess/sglang_inc.py
@@ -60,22 +60,71 @@ async def generate(self, request):
             # sglang defaults this to 128
             "max_new_tokens": request["stop_conditions"]["max_tokens"],
         }
-        num_output_tokens_so_far = 0
-        gen = await self.engine_client.async_generate(
-            input_ids=request["token_ids"], sampling_params=sampling_params, stream=True
-        )
+
+        # Check if this is a batch request
+        is_batch = "batch_token_ids" in request and request["batch_token_ids"]
+
+        if is_batch:
+            # Track tokens separately for each batch item
+            num_output_tokens_so_far = {}
+            logging.debug("received batch token ids")
+            gen = await self.engine_client.async_generate(
+                input_ids=request["batch_token_ids"],
+                sampling_params=sampling_params,
+                stream=True,
+            )
+        else:
+            num_output_tokens_so_far = 0
+            logging.debug("received token ids")
+            gen = await self.engine_client.async_generate(
+                input_ids=request["token_ids"],
+                sampling_params=sampling_params,
+                stream=True,
+            )
+
         async for res in gen:
             # res is a dict
-
+            logging.debug(f"res: {res}")
             finish_reason = res["meta_info"]["finish_reason"]
-            if finish_reason:
-                # Don't forward the stop token
-                out = {"token_ids": [], "finish_reason": finish_reason["type"]}
+
+            if is_batch:
+                # Handle batch response - get index from SGLang response
+                index = res.get("index", 0)
+                if index not in num_output_tokens_so_far:
+                    num_output_tokens_so_far[index] = 0
+
+                if finish_reason:
+                    logging.warning(f"finish_reason: {finish_reason}")
+                    # Final response for this batch item
+                    out = {
+                        "token_ids": [],
+                        "finish_reason": finish_reason["type"],
+                        "index": index,
+                    }
+                else:
+                    # Streaming response for this batch item
+                    next_total_toks = len(res["output_ids"])
+                    new_tokens = res["output_ids"][num_output_tokens_so_far[index] :]
+                    out = {
+                        "token_ids": new_tokens,
+                        "index": index,
+                    }
+                    num_output_tokens_so_far[index] = next_total_toks
             else:
-                next_total_toks = len(res["output_ids"])
-                out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
+                if finish_reason:
+                    out = {
+                        "token_ids": [],
+                        "finish_reason": finish_reason["type"],
+                    }
+                else:
+                    next_total_toks = len(res["output_ids"])
+                    new_tokens = res["output_ids"][num_output_tokens_so_far:]
+                    out = {
+                        "token_ids": new_tokens,
+                    }
+                    num_output_tokens_so_far = next_total_toks
+
             yield out
-            num_output_tokens_so_far = next_total_toks
 
 class EmbeddingRequestHandler(RequestHandler):
diff --git a/lib/engines/llamacpp/src/lib.rs b/lib/engines/llamacpp/src/lib.rs
index 5cf5dc9ebc..768421e610 100644
--- a/lib/engines/llamacpp/src/lib.rs
+++ b/lib/engines/llamacpp/src/lib.rs
@@ -269,6 +269,7 @@ fn run_request(
             cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64),
             log_probs: None, // TODO output.logprobs
             finish_reason: None,
+            index: None,
         };
         work_request
             .response_channel
diff --git a/lib/llm/src/backend.rs b/lib/llm/src/backend.rs
index 5af5c078e2..06f6096539 100644
--- a/lib/llm/src/backend.rs
+++ b/lib/llm/src/backend.rs
@@ -224,6 +224,7 @@ impl
                     log_probs: data.log_probs,
                     finish_reason: data.finish_reason,
                     //mdcsum: mdcsum.clone(),
+                    index: data.index,
                 })
             })
         });
diff --git a/lib/llm/src/engines.rs b/lib/llm/src/engines.rs
index 81be029659..95527ffff3 100644
--- a/lib/llm/src/engines.rs
+++ b/lib/llm/src/engines.rs
@@ -115,6 +115,7 @@ fn delta_core(tok: u32) -> Annotated {
         cum_log_probs: None,
         log_probs: None,
         finish_reason: None,
+        index: None,
     };
     Annotated::from_data(delta)
 }
diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs
index 9d3b6f6fb3..12afd9c622 100644
--- a/lib/llm/src/preprocessor.rs
+++ b/lib/llm/src/preprocessor.rs
@@ -53,7 +53,7 @@ use crate::protocols::{
 };
 use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};
 
-use crate::preprocessor::prompt::PromptFormatter;
+use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};
 
 pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
@@ -160,33 +160,79 @@ impl OpenAIPreprocessor {
         let mut annotations = HashMap::new();
         let mut builder = PreprocessedRequest::builder();
-        let use_raw_prompt = request
-            .nvext()
-            .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
-
-        let formatted_prompt = if use_raw_prompt {
-            match request.raw_prompt() {
-                Some(prompt) => prompt,
-                None => {
-                    tracing::warn!("Raw prompt requested but not available");
-                    self.formatter.render(request)?
+        // match request type before any conversion/processing
+        match request.prompt_input_type() {
+            PromptInput::Tokens(_) => {
+                if let Some(token_input) = request.extract_tokens() {
+                    match token_input {
+                        TokenInput::Single(tokens) => {
+                            builder.token_ids(tokens);
+                        }
+                        TokenInput::Batch(token_batches) => {
+                            if token_batches.len() == 1 {
+                                builder.token_ids(token_batches[0].clone());
+                            } else {
+                                builder.batch_token_ids(Some(token_batches));
+                                builder.token_ids(vec![]);
+                            }
+                        }
+                    }
                 }
             }
-        } else {
-            self.formatter.render(request)?
-        };
-
-        let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(&formatted_prompt))?;
+            PromptInput::Text(_) => {
+                if let Some(text_input) = request.extract_text() {
+                    match text_input {
+                        TextInput::Single(_) => {
+                            let use_raw_prompt = request
+                                .nvext()
+                                .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
+
+                            let formatted_prompt = if use_raw_prompt {
+                                match request.raw_prompt() {
+                                    Some(prompt) => prompt,
+                                    None => {
+                                        tracing::warn!("Raw prompt requested but not available");
+                                        self.formatter.render(request)?
+                                    }
+                                }
+                            } else {
+                                self.formatter.render(request)?
+                            };
+
+                            let encoding = tokio::task::block_in_place(|| {
+                                self.tokenizer.encode(&formatted_prompt)
+                            })?;
+
+                            if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
+                                annotations.insert(
+                                    ANNOTATION_FORMATTED_PROMPT.to_string(),
+                                    formatted_prompt,
+                                );
+                            }
-        if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
-            annotations.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), formatted_prompt);
-        }
+                            if request.has_annotation(ANNOTATION_TOKEN_IDS) {
+                                annotations.insert(
+                                    ANNOTATION_TOKEN_IDS.to_string(),
+                                    serde_json::to_string(&encoding.token_ids)?,
+                                );
+                            }
-        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
-            annotations.insert(
-                ANNOTATION_TOKEN_IDS.to_string(),
-                serde_json::to_string(&encoding.token_ids)?,
-            );
+                            builder.token_ids(encoding.token_ids);
+                        }
+                        TextInput::Batch(texts) => {
+                            let mut token_batches = Vec::new();
+                            // TODO: room for optimization here
+                            for text in texts {
+                                let encoding =
+                                    tokio::task::block_in_place(|| self.tokenizer.encode(&text))?;
+                                token_batches.push(encoding.token_ids);
+                            }
+                            builder.batch_token_ids(Some(token_batches));
+                            builder.token_ids(vec![]);
+                        }
+                    }
+                }
+            }
         }
 
         let mut stop_conditions = request.extract_stop_conditions()?;
@@ -207,9 +253,8 @@ impl OpenAIPreprocessor {
             builder.eos_token_ids(self.model_info.eos_token_ids());
         }
 
-        builder.token_ids(encoding.token_ids);
-        builder.sampling_options(request.extract_sampling_options()?);
         builder.stop_conditions(stop_conditions);
+        builder.sampling_options(request.extract_sampling_options()?);
         builder.annotations(request.annotations().unwrap_or_default());
         builder.mdc_sum(Some(self.mdcsum.clone()));
         builder.estimated_prefix_hit_num_blocks(None);
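The reworked `preprocess_request` reduces every prompt form to either `token_ids` or `batch_token_ids`: token inputs pass through (a one-element batch collapses to a single request), text inputs go through the prompt formatter and tokenizer, and a list of texts is tokenized element by element. A rough Python analogue of that dispatch, using the `tokenizers` package with a stand-in model name; this mirrors the logic only, not the Rust implementation or its template rendering:

```python
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("gpt2")  # stand-in for the deployed model's tokenizer


def preprocess(prompt):
    """Map the four OpenAI prompt shapes onto token_ids / batch_token_ids."""
    if isinstance(prompt, list) and prompt and isinstance(prompt[0], int):
        return {"token_ids": prompt, "batch_token_ids": None}
    if isinstance(prompt, list) and prompt and isinstance(prompt[0], list):
        if len(prompt) == 1:  # single-element batch collapses to a plain request
            return {"token_ids": prompt[0], "batch_token_ids": None}
        return {"token_ids": [], "batch_token_ids": prompt}
    if isinstance(prompt, str):
        return {"token_ids": tokenizer.encode(prompt).ids, "batch_token_ids": None}
    # list of strings: tokenize each element separately
    return {"token_ids": [], "batch_token_ids": [tokenizer.encode(t).ids for t in prompt]}


print(preprocess("Hello world"))
print(preprocess([[1, 2, 3], [4, 5]]))
print(preprocess(["first prompt", "second prompt"]))
```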
diff --git a/lib/llm/src/preprocessor/prompt.rs b/lib/llm/src/preprocessor/prompt.rs
index 1a7d129b6d..0a93926a1f 100644
--- a/lib/llm/src/preprocessor/prompt.rs
+++ b/lib/llm/src/preprocessor/prompt.rs
@@ -38,6 +38,24 @@ mod template;
 pub use template::ContextMixins;
 
+#[derive(Debug)]
+pub enum TokenInput {
+    Single(Vec<u32>),
+    Batch(Vec<Vec<u32>>),
+}
+
+#[derive(Debug)]
+pub enum TextInput {
+    Single(String),
+    Batch(Vec<String>),
+}
+
+#[derive(Debug)]
+pub enum PromptInput {
+    Tokens(TokenInput),
+    Text(TextInput),
+}
+
 /// Trait that defines a request that can map to an OpenAI-like request.
 pub trait OAIChatLikeRequest {
     fn messages(&self) -> Value;
@@ -49,6 +67,20 @@ pub trait OAIChatLikeRequest {
     }
 
     fn should_add_generation_prompt(&self) -> bool;
+
+    /// Returns the type of input for the prompt. Default is Text.
+    fn prompt_input_type(&self) -> PromptInput {
+        PromptInput::Text(TextInput::Single(String::new()))
+    }
+
+    /// Extract tokens if the input is pre-tokenized
+    fn extract_tokens(&self) -> Option<TokenInput> {
+        None
+    }
+
+    fn extract_text(&self) -> Option<TextInput> {
+        None
+    }
 }
 
 pub trait OAIPromptFormatter: Send + Sync + 'static {
diff --git a/lib/llm/src/preprocessor/prompt/template/oai.rs b/lib/llm/src/preprocessor/prompt/template/oai.rs
index eb384464dc..0cc4d3b7e8 100644
--- a/lib/llm/src/preprocessor/prompt/template/oai.rs
+++ b/lib/llm/src/preprocessor/prompt/template/oai.rs
@@ -22,6 +22,8 @@ use crate::protocols::openai::{
 };
 use tracing;
 
+use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};
+
 impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
     fn messages(&self) -> Value {
         Value::from_serialize(&self.inner.messages)
@@ -53,6 +55,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
             true
         }
     }
+
+    fn extract_text(&self) -> Option<TextInput> {
+        Some(TextInput::Single(String::new()))
+    }
 }
 
 impl OAIChatLikeRequest for NvCreateCompletionRequest {
@@ -72,6 +78,48 @@ impl OAIChatLikeRequest for NvCreateCompletionRequest {
     fn should_add_generation_prompt(&self) -> bool {
         true
     }
+
+    fn prompt_input_type(&self) -> PromptInput {
+        match &self.inner.prompt {
+            async_openai::types::Prompt::IntegerArray(_) => {
+                PromptInput::Tokens(TokenInput::Single(vec![]))
+            }
+            async_openai::types::Prompt::ArrayOfIntegerArray(_) => {
+                PromptInput::Tokens(TokenInput::Batch(vec![]))
+            }
+            async_openai::types::Prompt::String(_) => {
+                PromptInput::Text(TextInput::Single(String::new()))
+            }
+            async_openai::types::Prompt::StringArray(_) => {
+                PromptInput::Text(TextInput::Batch(vec![]))
+            }
+        }
+    }
+
+    fn extract_tokens(&self) -> Option<TokenInput> {
+        match &self.inner.prompt {
+            async_openai::types::Prompt::IntegerArray(tokens) => Some(TokenInput::Single(
+                tokens.iter().map(|&t| t as u32).collect(),
+            )),
+            async_openai::types::Prompt::ArrayOfIntegerArray(arrays) => Some(TokenInput::Batch(
+                arrays
+                    .iter()
+                    .map(|arr| arr.iter().map(|&t| t as u32).collect())
+                    .collect(),
+            )),
+            _ => None,
+        }
+    }
+
+    fn extract_text(&self) -> Option<TextInput> {
+        match &self.inner.prompt {
+            async_openai::types::Prompt::String(text) => Some(TextInput::Single(text.to_string())),
+            async_openai::types::Prompt::StringArray(texts) => {
+                Some(TextInput::Batch(texts.to_vec()))
+            }
+            _ => None,
+        }
+    }
 }
 
 impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
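For reference, these are the four `prompt` shapes the OpenAI completions API permits and the `PromptInput` variant each one now maps to. The URL and model name are placeholders for whatever the Dynamo frontend serves in a given deployment:

```python
import requests

URL = "http://localhost:8000/v1/completions"  # placeholder frontend address
MODEL = "example-model"                       # placeholder served model name

prompts = {
    "Text(Single)":   "What is the capital of France?",
    "Text(Batch)":    ["What is 2 + 2?", "Name a prime number."],
    "Tokens(Single)": [101, 2023, 2003, 102],
    "Tokens(Batch)":  [[101, 2023, 102], [101, 7592, 102]],
}

for variant, prompt in prompts.items():
    resp = requests.post(URL, json={"model": MODEL, "prompt": prompt, "max_tokens": 16})
    print(variant, resp.status_code)
```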
diff --git a/lib/llm/src/protocols/common/llm_backend.rs b/lib/llm/src/protocols/common/llm_backend.rs
index b9e4d08a79..92b68b2b4e 100644
--- a/lib/llm/src/protocols/common/llm_backend.rs
+++ b/lib/llm/src/protocols/common/llm_backend.rs
@@ -46,6 +46,9 @@ pub struct BackendOutput {
     pub finish_reason: Option<FinishReason>,
     // Model Deployment Card checksum
     //pub mdcsum: String,
+
+    // Index field for batch requests to match OpenAI format
+    pub index: Option,
 }
 
 /// The LLM engine and backnd with manage it's own state, specifically translating how a
@@ -77,6 +80,9 @@ pub struct LLMEngineOutput {
     // TODO: Enrich this with more information as can apply our first-level postprocessing
     // logic and return more detailed information
     pub finish_reason: Option<FinishReason>,
+
+    // Index field for batch requests to match OpenAI format
+    pub index: Option,
 }
 
 impl LLMEngineOutput {
@@ -88,6 +94,7 @@ impl LLMEngineOutput {
             cum_log_probs: None,
             log_probs: None,
             finish_reason: Some(FinishReason::Cancelled),
+            index: None,
         }
     }
@@ -99,6 +106,7 @@ impl LLMEngineOutput {
             cum_log_probs: None,
             log_probs: None,
             finish_reason: Some(FinishReason::Stop),
+            index: None,
         }
     }
@@ -110,6 +118,7 @@ impl LLMEngineOutput {
             cum_log_probs: None,
             log_probs: None,
             finish_reason: Some(FinishReason::Length),
+            index: None,
         }
     }
@@ -121,6 +130,7 @@ impl LLMEngineOutput {
             cum_log_probs: None,
             log_probs: None,
             finish_reason: Some(FinishReason::Error(err_msg)),
+            index: None,
         }
     }
 }
diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs
index 6b3be76069..34ff3d072d 100644
--- a/lib/llm/src/protocols/common/preprocessor.rs
+++ b/lib/llm/src/protocols/common/preprocessor.rs
@@ -26,6 +26,10 @@ pub struct PreprocessedRequest {
     /// Type of prompt
     pub token_ids: Vec,
+    /// Batch Token Ids = for batch completion requests (i.e using ArrayOfIntegerArray type from OpenAI /completions)
+    #[builder(default)]
+    pub batch_token_ids: Option>>,
+
     /// StopConditions are conditions that the inference engine will use to stop generation.
     pub stop_conditions: StopConditions,
diff --git a/lib/llm/src/protocols/openai/completions/delta.rs b/lib/llm/src/protocols/openai/completions/delta.rs
index d0b617850b..542f2a73c1 100644
--- a/lib/llm/src/protocols/openai/completions/delta.rs
+++ b/lib/llm/src/protocols/openai/completions/delta.rs
@@ -131,8 +131,9 @@ impl crate::protocols::openai::DeltaGeneratorExt for DeltaGe
         };
 
         // create choice
-        let index = 0;
-        Ok(self.create_choice(index, delta.text, finish_reason))
+        let index = delta.index.unwrap_or(0).into();
+        let response = self.create_choice(index, delta.text.clone(), finish_reason);
+        Ok(response)
     }
 
     fn get_isl(&self) -> Option {
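With `index` carried through `BackendOutput` into `create_choice`, a streamed batch completion interleaves chunks whose `choices[0].index` identifies the originating prompt. An illustrative, heavily trimmed pair of stream chunks for a two-prompt batch, plus the grouping a client would do to reassemble per-prompt text:

```python
from collections import defaultdict

chunk_from_prompt_0 = {
    "object": "text_completion",
    "choices": [{"index": 0, "text": " Paris", "finish_reason": None}],
}
chunk_from_prompt_1 = {
    "object": "text_completion",
    "choices": [{"index": 1, "text": " 4", "finish_reason": None}],
}

# Reassemble each prompt's output by grouping on the choice index.
texts = defaultdict(str)
for chunk in [chunk_from_prompt_0, chunk_from_prompt_1]:
    choice = chunk["choices"][0]
    texts[choice["index"]] += choice["text"]
print(dict(texts))  # {0: ' Paris', 1: ' 4'}
```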