Skip to content

Commit 55eeabc

Browse files
committed
refactor: refactored using Choice and CompletionFinishReason
1 parent 57f5725 commit 55eeabc

File tree

6 files changed

+123
-88
lines changed

6 files changed

+123
-88
lines changed

lib/llm/src/engines.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<Completi
237237
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
238238
id += 1;
239239
}
240-
let response = deltas.create_choice(0, None, Some("stop".to_string()));
240+
let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop));
241241
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
242242

243243
};

lib/llm/src/protocols/common.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ pub enum FinishReason {
6464

6565
#[serde(rename = "cancelled")]
6666
Cancelled,
67+
68+
#[serde(rename = "content_filter")]
69+
ContentFilter,
6770
}
6871

6972
impl std::fmt::Display for FinishReason {
@@ -74,6 +77,7 @@ impl std::fmt::Display for FinishReason {
7477
FinishReason::Stop => write!(f, "stop"),
7578
FinishReason::Error(msg) => write!(f, "error: {}", msg),
7679
FinishReason::Cancelled => write!(f, "cancelled"),
80+
FinishReason::ContentFilter => write!(f, "content_filter"),
7781
}
7882
}
7983
}
@@ -93,6 +97,33 @@ impl std::str::FromStr for FinishReason {
9397
}
9498
}
9599

100+
impl From<FinishReason> for async_openai::types::CompletionFinishReason {
101+
fn from(reason: FinishReason) -> Self {
102+
match reason {
103+
FinishReason::EoS | FinishReason::Stop | FinishReason::Cancelled => {
104+
async_openai::types::CompletionFinishReason::Stop
105+
}
106+
FinishReason::ContentFilter => {
107+
async_openai::types::CompletionFinishReason::ContentFilter
108+
}
109+
FinishReason::Length => async_openai::types::CompletionFinishReason::Length,
110+
FinishReason::Error(_) => async_openai::types::CompletionFinishReason::Stop,
111+
}
112+
}
113+
}
114+
115+
impl From<async_openai::types::CompletionFinishReason> for FinishReason {
116+
fn from(reason: async_openai::types::CompletionFinishReason) -> Self {
117+
match reason {
118+
async_openai::types::CompletionFinishReason::Stop => FinishReason::Stop,
119+
async_openai::types::CompletionFinishReason::Length => FinishReason::Length,
120+
async_openai::types::CompletionFinishReason::ContentFilter => {
121+
FinishReason::ContentFilter
122+
}
123+
}
124+
}
125+
}
126+
96127
/// LLM Inference Engines can accept a variety of input types. Not all Engines will support all
97128
/// input types. For example, the trtllm::AsyncEngine only supports `PromptType::Tokens` as an
98129
/// input type. The higher-level `Backend` class is a general wrapper around Engines that will

lib/llm/src/protocols/openai/chat_completions/delta.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
198198
Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop),
199199
Some(common::FinishReason::Length) => Some(async_openai::types::FinishReason::Length),
200200
Some(common::FinishReason::Cancelled) => Some(async_openai::types::FinishReason::Stop),
201+
Some(common::FinishReason::ContentFilter) => {
202+
Some(async_openai::types::FinishReason::ContentFilter)
203+
}
201204
Some(common::FinishReason::Error(err_msg)) => {
202205
return Err(anyhow::anyhow!(err_msg));
203206
}

lib/llm/src/protocols/openai/completions.rs

Lines changed: 24 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pub struct CompletionResponse {
4949
pub id: String,
5050

5151
/// The list of completion choices the model generated for the input prompt.
52-
pub choices: Vec<CompletionChoice>,
52+
pub choices: Vec<async_openai::types::Choice>,
5353

5454
/// The Unix timestamp (in seconds) of when the completion was created.
5555
pub created: u64,
@@ -76,35 +76,12 @@ pub struct CompletionResponse {
7676
// pub nvext: Option<NimResponseExt>,
7777
}
7878

79-
/// Legacy OpenAI CompletionResponse Choice component
80-
#[derive(Clone, Debug, Deserialize, Serialize, Builder)]
81-
pub struct CompletionChoice {
82-
#[builder(setter(into))]
83-
pub text: String,
84-
85-
#[builder(default = "0")]
86-
pub index: u64,
87-
88-
#[builder(default, setter(into, strip_option))]
89-
pub finish_reason: Option<String>,
90-
91-
#[serde(skip_serializing_if = "Option::is_none")]
92-
#[builder(default, setter(strip_option))]
93-
pub logprobs: Option<async_openai::types::Logprobs>,
94-
}
95-
96-
impl ContentProvider for CompletionChoice {
79+
impl ContentProvider for async_openai::types::Choice {
9780
fn content(&self) -> String {
9881
self.text.clone()
9982
}
10083
}
10184

102-
impl CompletionChoice {
103-
pub fn builder() -> CompletionChoiceBuilder {
104-
CompletionChoiceBuilder::default()
105-
}
106-
}
107-
10885
pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
10986
match prompt {
11087
async_openai::types::Prompt::String(s) => s.clone(),
@@ -226,7 +203,7 @@ impl ResponseFactory {
226203

227204
pub fn make_response(
228205
&self,
229-
choice: CompletionChoice,
206+
choice: async_openai::types::Choice,
230207
usage: Option<async_openai::types::CompletionUsage>,
231208
) -> CompletionResponse {
232209
CompletionResponse {
@@ -294,27 +271,30 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
294271
}
295272
}
296273

297-
impl TryFrom<common::StreamingCompletionResponse> for CompletionChoice {
274+
impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choice {
298275
type Error = anyhow::Error;
299276

300277
fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
301-
let choice = CompletionChoice {
302-
text: response
303-
.delta
304-
.text
305-
.ok_or(anyhow::anyhow!("No text in response"))?,
306-
index: response.delta.index.unwrap_or(0) as u64,
307-
logprobs: None,
308-
finish_reason: match &response.delta.finish_reason {
309-
Some(common::FinishReason::EoS) => Some("stop".to_string()),
310-
Some(common::FinishReason::Stop) => Some("stop".to_string()),
311-
Some(common::FinishReason::Length) => Some("length".to_string()),
312-
Some(common::FinishReason::Error(err_msg)) => {
313-
return Err(anyhow::anyhow!("finish_reason::error = {}", err_msg));
314-
}
315-
Some(common::FinishReason::Cancelled) => Some("cancelled".to_string()),
316-
None => None,
317-
},
278+
let text = response
279+
.delta
280+
.text
281+
.ok_or(anyhow::anyhow!("No text in response"))?;
282+
283+
// Safety: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
284+
// so we're fairly safe knowing we won't generate that many Choices
285+
let index = response.delta.index.unwrap_or(0) as u32;
286+
287+
// TODO handle aggregating logprobs
288+
let logprobs = None;
289+
290+
let finish_reason: Option<async_openai::types::CompletionFinishReason> =
291+
response.delta.finish_reason.map(Into::into);
292+
293+
let choice = async_openai::types::Choice {
294+
text,
295+
index,
296+
logprobs,
297+
finish_reason,
318298
};
319299

320300
Ok(choice)

lib/llm/src/protocols/openai/completions/aggregator.rs

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
// See the License for the specific language governing permissions and
1414
// limitations under the License.
1515

16-
use std::{collections::HashMap, str::FromStr};
16+
use std::collections::HashMap;
1717

1818
use anyhow::Result;
1919
use futures::StreamExt;
2020

21-
use super::{CompletionChoice, CompletionResponse};
21+
use super::CompletionResponse;
2222
use crate::protocols::{
2323
codec::{Message, SseCodecError},
2424
common::FinishReason,
@@ -98,22 +98,31 @@ impl DeltaAggregator {
9898
let state_choice =
9999
aggregator
100100
.choices
101-
.entry(choice.index)
101+
.entry(choice.index as u64)
102102
.or_insert(DeltaChoice {
103-
index: choice.index,
103+
index: choice.index as u64,
104104
text: "".to_string(),
105105
finish_reason: None,
106106
logprobs: choice.logprobs,
107107
});
108108

109109
state_choice.text.push_str(&choice.text);
110110

111-
// todo - handle logprobs
112-
113-
if let Some(finish_reason) = choice.finish_reason {
114-
let reason = FinishReason::from_str(&finish_reason).ok();
115-
state_choice.finish_reason = reason;
116-
}
111+
// TODO - handle logprobs
112+
113+
// Handle CompletionFinishReason -> FinishReason conversion
114+
state_choice.finish_reason = match choice.finish_reason {
115+
Some(async_openai::types::CompletionFinishReason::Stop) => {
116+
Some(FinishReason::Stop)
117+
}
118+
Some(async_openai::types::CompletionFinishReason::Length) => {
119+
Some(FinishReason::Length)
120+
}
121+
Some(async_openai::types::CompletionFinishReason::ContentFilter) => {
122+
Some(FinishReason::ContentFilter)
123+
}
124+
None => None,
125+
};
117126
}
118127
}
119128
aggregator
@@ -131,7 +140,7 @@ impl DeltaAggregator {
131140
let mut choices: Vec<_> = aggregator
132141
.choices
133142
.into_values()
134-
.map(CompletionChoice::from)
143+
.map(async_openai::types::Choice::from)
135144
.collect();
136145

137146
choices.sort_by(|a, b| a.index.cmp(&b.index));
@@ -148,12 +157,12 @@ impl DeltaAggregator {
148157
}
149158
}
150159

151-
impl From<DeltaChoice> for CompletionChoice {
160+
impl From<DeltaChoice> for async_openai::types::Choice {
152161
fn from(delta: DeltaChoice) -> Self {
153-
let finish_reason = delta.finish_reason.map(|reason| reason.to_string());
162+
let finish_reason = delta.finish_reason.map(Into::into);
154163

155-
CompletionChoice {
156-
index: delta.index,
164+
async_openai::types::Choice {
165+
index: delta.index as u32,
157166
text: delta.text,
158167
finish_reason,
159168
logprobs: delta.logprobs,
@@ -178,25 +187,34 @@ impl CompletionResponse {
178187

179188
#[cfg(test)]
180189
mod tests {
181-
use crate::protocols::openai::completions::{CompletionChoice, CompletionResponse};
190+
use std::str::FromStr;
182191

183-
use super::*;
184192
use futures::stream;
185193

194+
use super::*;
195+
use crate::protocols::openai::completions::CompletionResponse;
196+
186197
fn create_test_delta(
187198
index: u64,
188199
text: &str,
189200
finish_reason: Option<String>,
190201
) -> Annotated<CompletionResponse> {
202+
// This will silently discard invalid finish_reason values and fall back
203+
// to None - totally fine since this is test code
204+
let finish_reason = finish_reason
205+
.as_deref()
206+
.and_then(|s| FinishReason::from_str(s).ok())
207+
.map(Into::into);
208+
191209
Annotated {
192210
data: Some(CompletionResponse {
193211
id: "test_id".to_string(),
194212
model: "meta/llama-3.1-8b".to_string(),
195213
created: 1234567890,
196214
usage: None,
197215
system_fingerprint: None,
198-
choices: vec![CompletionChoice {
199-
index,
216+
choices: vec![async_openai::types::Choice {
217+
index: index as u32,
200218
text: text.to_string(),
201219
finish_reason,
202220
logprobs: None,
@@ -255,7 +273,10 @@ mod tests {
255273
let choice = &response.choices[0];
256274
assert_eq!(choice.index, 0);
257275
assert_eq!(choice.text, "Hello,".to_string());
258-
assert_eq!(choice.finish_reason, Some("length".to_string()));
276+
assert_eq!(
277+
choice.finish_reason,
278+
Some(async_openai::types::CompletionFinishReason::Length)
279+
);
259280
assert!(choice.logprobs.is_none());
260281
}
261282

@@ -283,7 +304,10 @@ mod tests {
283304
let choice = &response.choices[0];
284305
assert_eq!(choice.index, 0);
285306
assert_eq!(choice.text, "Hello, world!".to_string());
286-
assert_eq!(choice.finish_reason, Some("stop".to_string()));
307+
assert_eq!(
308+
choice.finish_reason,
309+
Some(async_openai::types::CompletionFinishReason::Stop)
310+
);
287311
}
288312

289313
#[tokio::test]
@@ -297,16 +321,16 @@ mod tests {
297321
usage: None,
298322
system_fingerprint: None,
299323
choices: vec![
300-
CompletionChoice {
324+
async_openai::types::Choice {
301325
index: 0,
302326
text: "Choice 0".to_string(),
303-
finish_reason: Some("stop".to_string()),
327+
finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
304328
logprobs: None,
305329
},
306-
CompletionChoice {
330+
async_openai::types::Choice {
307331
index: 1,
308332
text: "Choice 1".to_string(),
309-
finish_reason: Some("stop".to_string()),
333+
finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
310334
logprobs: None,
311335
},
312336
],
@@ -333,11 +357,17 @@ mod tests {
333357
let choice0 = &response.choices[0];
334358
assert_eq!(choice0.index, 0);
335359
assert_eq!(choice0.text, "Choice 0".to_string());
336-
assert_eq!(choice0.finish_reason, Some("stop".to_string()));
360+
assert_eq!(
361+
choice0.finish_reason,
362+
Some(async_openai::types::CompletionFinishReason::Stop)
363+
);
337364

338365
let choice1 = &response.choices[1];
339366
assert_eq!(choice1.index, 1);
340367
assert_eq!(choice1.text, "Choice 1".to_string());
341-
assert_eq!(choice1.finish_reason, Some("stop".to_string()));
368+
assert_eq!(
369+
choice1.finish_reason,
370+
Some(async_openai::types::CompletionFinishReason::Stop)
371+
);
342372
}
343373
}

0 commit comments

Comments
 (0)