1313// See the License for the specific language governing permissions and
1414// limitations under the License.
1515
16- use std:: { collections:: HashMap , str :: FromStr } ;
16+ use std:: collections:: HashMap ;
1717
1818use anyhow:: Result ;
1919use futures:: StreamExt ;
2020
21- use super :: { CompletionChoice , CompletionResponse } ;
21+ use super :: CompletionResponse ;
2222use crate :: protocols:: {
2323 codec:: { Message , SseCodecError } ,
2424 common:: FinishReason ,
@@ -98,22 +98,31 @@ impl DeltaAggregator {
9898 let state_choice =
9999 aggregator
100100 . choices
101- . entry ( choice. index )
101+ . entry ( choice. index as u64 )
102102 . or_insert ( DeltaChoice {
103- index : choice. index ,
103+ index : choice. index as u64 ,
104104 text : "" . to_string ( ) ,
105105 finish_reason : None ,
106106 logprobs : choice. logprobs ,
107107 } ) ;
108108
109109 state_choice. text . push_str ( & choice. text ) ;
110110
111- // todo - handle logprobs
112-
113- if let Some ( finish_reason) = choice. finish_reason {
114- let reason = FinishReason :: from_str ( & finish_reason) . ok ( ) ;
115- state_choice. finish_reason = reason;
116- }
111+ // TODO - handle logprobs
112+
113+ // Handle CompletionFinishReason -> FinishReason conversation
114+ state_choice. finish_reason = match choice. finish_reason {
115+ Some ( async_openai:: types:: CompletionFinishReason :: Stop ) => {
116+ Some ( FinishReason :: Stop )
117+ }
118+ Some ( async_openai:: types:: CompletionFinishReason :: Length ) => {
119+ Some ( FinishReason :: Length )
120+ }
121+ Some ( async_openai:: types:: CompletionFinishReason :: ContentFilter ) => {
122+ Some ( FinishReason :: ContentFilter )
123+ }
124+ None => None ,
125+ } ;
117126 }
118127 }
119128 aggregator
@@ -131,7 +140,7 @@ impl DeltaAggregator {
131140 let mut choices: Vec < _ > = aggregator
132141 . choices
133142 . into_values ( )
134- . map ( CompletionChoice :: from)
143+ . map ( async_openai :: types :: Choice :: from)
135144 . collect ( ) ;
136145
137146 choices. sort_by ( |a, b| a. index . cmp ( & b. index ) ) ;
@@ -148,12 +157,12 @@ impl DeltaAggregator {
148157 }
149158}
150159
151- impl From < DeltaChoice > for CompletionChoice {
160+ impl From < DeltaChoice > for async_openai :: types :: Choice {
152161 fn from ( delta : DeltaChoice ) -> Self {
153- let finish_reason = delta. finish_reason . map ( |reason| reason . to_string ( ) ) ;
162+ let finish_reason = delta. finish_reason . map ( Into :: into ) ;
154163
155- CompletionChoice {
156- index : delta. index ,
164+ async_openai :: types :: Choice {
165+ index : delta. index as u32 ,
157166 text : delta. text ,
158167 finish_reason,
159168 logprobs : delta. logprobs ,
@@ -178,25 +187,34 @@ impl CompletionResponse {
178187
179188#[ cfg( test) ]
180189mod tests {
181- use crate :: protocols :: openai :: completions :: { CompletionChoice , CompletionResponse } ;
190+ use std :: str :: FromStr ;
182191
183- use super :: * ;
184192 use futures:: stream;
185193
194+ use super :: * ;
195+ use crate :: protocols:: openai:: completions:: CompletionResponse ;
196+
186197 fn create_test_delta (
187198 index : u64 ,
188199 text : & str ,
189200 finish_reason : Option < String > ,
190201 ) -> Annotated < CompletionResponse > {
202+ // This will silently discard invalid_finish reason values and fall back
203+ // to None - totally fine since this is test code
204+ let finish_reason = finish_reason
205+ . as_deref ( )
206+ . and_then ( |s| FinishReason :: from_str ( s) . ok ( ) )
207+ . map ( Into :: into) ;
208+
191209 Annotated {
192210 data : Some ( CompletionResponse {
193211 id : "test_id" . to_string ( ) ,
194212 model : "meta/llama-3.1-8b" . to_string ( ) ,
195213 created : 1234567890 ,
196214 usage : None ,
197215 system_fingerprint : None ,
198- choices : vec ! [ CompletionChoice {
199- index,
216+ choices : vec ! [ async_openai :: types :: Choice {
217+ index: index as u32 ,
200218 text: text. to_string( ) ,
201219 finish_reason,
202220 logprobs: None ,
@@ -255,7 +273,10 @@ mod tests {
255273 let choice = & response. choices [ 0 ] ;
256274 assert_eq ! ( choice. index, 0 ) ;
257275 assert_eq ! ( choice. text, "Hello," . to_string( ) ) ;
258- assert_eq ! ( choice. finish_reason, Some ( "length" . to_string( ) ) ) ;
276+ assert_eq ! (
277+ choice. finish_reason,
278+ Some ( async_openai:: types:: CompletionFinishReason :: Length )
279+ ) ;
259280 assert ! ( choice. logprobs. is_none( ) ) ;
260281 }
261282
@@ -283,7 +304,10 @@ mod tests {
283304 let choice = & response. choices [ 0 ] ;
284305 assert_eq ! ( choice. index, 0 ) ;
285306 assert_eq ! ( choice. text, "Hello, world!" . to_string( ) ) ;
286- assert_eq ! ( choice. finish_reason, Some ( "stop" . to_string( ) ) ) ;
307+ assert_eq ! (
308+ choice. finish_reason,
309+ Some ( async_openai:: types:: CompletionFinishReason :: Stop )
310+ ) ;
287311 }
288312
289313 #[ tokio:: test]
@@ -297,16 +321,16 @@ mod tests {
297321 usage : None ,
298322 system_fingerprint : None ,
299323 choices : vec ! [
300- CompletionChoice {
324+ async_openai :: types :: Choice {
301325 index: 0 ,
302326 text: "Choice 0" . to_string( ) ,
303- finish_reason: Some ( "stop" . to_string ( ) ) ,
327+ finish_reason: Some ( async_openai :: types :: CompletionFinishReason :: Stop ) ,
304328 logprobs: None ,
305329 } ,
306- CompletionChoice {
330+ async_openai :: types :: Choice {
307331 index: 1 ,
308332 text: "Choice 1" . to_string( ) ,
309- finish_reason: Some ( "stop" . to_string ( ) ) ,
333+ finish_reason: Some ( async_openai :: types :: CompletionFinishReason :: Stop ) ,
310334 logprobs: None ,
311335 } ,
312336 ] ,
@@ -333,11 +357,17 @@ mod tests {
333357 let choice0 = & response. choices [ 0 ] ;
334358 assert_eq ! ( choice0. index, 0 ) ;
335359 assert_eq ! ( choice0. text, "Choice 0" . to_string( ) ) ;
336- assert_eq ! ( choice0. finish_reason, Some ( "stop" . to_string( ) ) ) ;
360+ assert_eq ! (
361+ choice0. finish_reason,
362+ Some ( async_openai:: types:: CompletionFinishReason :: Stop )
363+ ) ;
337364
338365 let choice1 = & response. choices [ 1 ] ;
339366 assert_eq ! ( choice1. index, 1 ) ;
340367 assert_eq ! ( choice1. text, "Choice 1" . to_string( ) ) ;
341- assert_eq ! ( choice1. finish_reason, Some ( "stop" . to_string( ) ) ) ;
368+ assert_eq ! (
369+ choice1. finish_reason,
370+ Some ( async_openai:: types:: CompletionFinishReason :: Stop )
371+ ) ;
342372 }
343373}
0 commit comments