@@ -74,36 +74,42 @@ def setup_class(cls):
7474
7575 def test_generate (self ):
7676 prompts = ["Hello, AI!" , "Tell me a joke" ]
77- outputs = self .client .generate (prompts )["completion_ids" ]
77+ outputs = self .client .generate (prompts )
78+ prompt_ids = outputs ["prompt_ids" ]
79+ completion_ids = outputs ["completion_ids" ]
7880
79- # Check that the output is a list
80- assert isinstance (outputs , list )
81+ # Check that the outputs are lists
82+ assert isinstance (prompt_ids , list )
83+ assert isinstance (completion_ids , list )
8184
82- # Check that the number of generated sequences is equal to the number of prompts
83- assert len (outputs ) == len (prompts )
85+ # Check that the number of sequences are equal to the number of prompts
86+ assert len (prompt_ids ) == len (prompts )
87+ assert len (completion_ids ) == len (prompts )
8488
85- # Check that the generated sequences are lists of integers
86- for seq in outputs :
89+ # Check that the sequences are lists of integers
90+ for seq in prompt_ids :
91+ assert all (isinstance (tok , int ) for tok in seq )
92+ for seq in completion_ids :
8793 assert all (isinstance (tok , int ) for tok in seq )
8894
8995 def test_generate_with_params (self ):
9096 prompts = ["Hello, AI!" , "Tell me a joke" ]
91- outputs = self .client .generate (prompts , n = 2 , repetition_penalty = 0.9 , temperature = 0.8 , max_tokens = 32 )[
97+ completion_ids = self .client .generate (prompts , n = 2 , repetition_penalty = 0.9 , temperature = 0.8 , max_tokens = 32 )[
9298 "completion_ids"
9399 ]
94100
95101 # Check that the output is a list
96- assert isinstance (outputs , list )
102+ assert isinstance (completion_ids , list )
97103
98104 # Check that the number of generated sequences is 2 times the number of prompts
99- assert len (outputs ) == 2 * len (prompts )
105+ assert len (completion_ids ) == 2 * len (prompts )
100106
101107 # Check that the generated sequences are lists of integers
102- for seq in outputs :
108+ for seq in completion_ids :
103109 assert all (isinstance (tok , int ) for tok in seq )
104110
105111 # Check that the length of the generated sequences is less than or equal to 32
106- for seq in outputs :
112+ for seq in completion_ids :
107113 assert len (seq ) <= 32
108114
109115 def test_update_model_params (self ):
@@ -148,36 +154,42 @@ def setup_class(cls):
148154
149155 def test_generate (self ):
150156 prompts = ["Hello, AI!" , "Tell me a joke" ]
151- outputs = self .client .generate (prompts )["completion_ids" ]
157+ outputs = self .client .generate (prompts )
158+ prompt_ids = outputs ["prompt_ids" ]
159+ completion_ids = outputs ["completion_ids" ]
152160
153- # Check that the output is a list
154- assert isinstance (outputs , list )
161+ # Check that the outputs are lists
162+ assert isinstance (prompt_ids , list )
163+ assert isinstance (completion_ids , list )
155164
156- # Check that the number of generated sequences is equal to the number of prompts
157- assert len (outputs ) == len (prompts )
165+ # Check that the number of sequences are equal to the number of prompts
166+ assert len (prompt_ids ) == len (prompts )
167+ assert len (completion_ids ) == len (prompts )
158168
159- # Check that the generated sequences are lists of integers
160- for seq in outputs :
169+ # Check that the sequences are lists of integers
170+ for seq in prompt_ids :
171+ assert all (isinstance (tok , int ) for tok in seq )
172+ for seq in completion_ids :
161173 assert all (isinstance (tok , int ) for tok in seq )
162174
163175 def test_generate_with_params (self ):
164176 prompts = ["Hello, AI!" , "Tell me a joke" ]
165- outputs = self .client .generate (prompts , n = 2 , repetition_penalty = 0.9 , temperature = 0.8 , max_tokens = 32 )[
177+ completion_ids = self .client .generate (prompts , n = 2 , repetition_penalty = 0.9 , temperature = 0.8 , max_tokens = 32 )[
166178 "completion_ids"
167179 ]
168180
169181 # Check that the output is a list
170- assert isinstance (outputs , list )
182+ assert isinstance (completion_ids , list )
171183
172184 # Check that the number of generated sequences is 2 times the number of prompts
173- assert len (outputs ) == 2 * len (prompts )
185+ assert len (completion_ids ) == 2 * len (prompts )
174186
175187 # Check that the generated sequences are lists of integers
176- for seq in outputs :
188+ for seq in completion_ids :
177189 assert all (isinstance (tok , int ) for tok in seq )
178190
179191 # Check that the length of the generated sequences is less than or equal to 32
180- for seq in outputs :
192+ for seq in completion_ids :
181193 assert len (seq ) <= 32
182194
183195 def test_update_model_params (self ):
@@ -224,16 +236,22 @@ def setup_class(cls):
224236
225237 def test_generate (self ):
226238 prompts = ["Hello, AI!" , "Tell me a joke" ]
227- outputs = self .client .generate (prompts )["completion_ids" ]
239+ outputs = self .client .generate (prompts )
240+ prompt_ids = outputs ["prompt_ids" ]
241+ completion_ids = outputs ["completion_ids" ]
228242
229- # Check that the output is a list
230- assert isinstance (outputs , list )
243+ # Check that the outputs are lists
244+ assert isinstance (prompt_ids , list )
245+ assert isinstance (completion_ids , list )
231246
232- # Check that the number of generated sequences is equal to the number of prompts
233- assert len (outputs ) == len (prompts )
247+ # Check that the number of sequences are equal to the number of prompts
248+ assert len (prompt_ids ) == len (prompts )
249+ assert len (completion_ids ) == len (prompts )
234250
235- # Check that the generated sequences are lists of integers
236- for seq in outputs :
251+ # Check that the sequences are lists of integers
252+ for seq in prompt_ids :
253+ assert all (isinstance (tok , int ) for tok in seq )
254+ for seq in completion_ids :
237255 assert all (isinstance (tok , int ) for tok in seq )
238256
239257 def test_update_model_params (self ):
@@ -280,16 +298,22 @@ def setup_class(cls):
280298
281299 def test_generate (self ):
282300 prompts = ["Hello, AI!" , "Tell me a joke" ]
283- outputs = self .client .generate (prompts )["completion_ids" ]
301+ outputs = self .client .generate (prompts )
302+ prompt_ids = outputs ["prompt_ids" ]
303+ completion_ids = outputs ["completion_ids" ]
284304
285- # Check that the output is a list
286- assert isinstance (outputs , list )
305+ # Check that the outputs are lists
306+ assert isinstance (prompt_ids , list )
307+ assert isinstance (completion_ids , list )
287308
288- # Check that the number of generated sequences is equal to the number of prompts
289- assert len (outputs ) == len (prompts )
309+ # Check that the number of sequences are equal to the number of prompts
310+ assert len (prompt_ids ) == len (prompts )
311+ assert len (completion_ids ) == len (prompts )
290312
291- # Check that the generated sequences are lists of integers
292- for seq in outputs :
313+ # Check that the sequences are lists of integers
314+ for seq in prompt_ids :
315+ assert all (isinstance (tok , int ) for tok in seq )
316+ for seq in completion_ids :
293317 assert all (isinstance (tok , int ) for tok in seq )
294318
295319 def test_update_model_params (self ):
@@ -336,9 +360,13 @@ def test_init_communicator_with_device_int(self):
336360
337361 # Test basic functionality
338362 prompts = ["Hello, AI!" ]
339- outputs = client .generate (prompts )["completion_ids" ]
340- assert isinstance (outputs , list )
341- assert len (outputs ) == len (prompts )
363+ outputs = client .generate (prompts )
364+ prompt_ids = outputs ["prompt_ids" ]
365+ completion_ids = outputs ["completion_ids" ]
366+ assert isinstance (prompt_ids , list )
367+ assert len (prompt_ids ) == len (prompts )
368+ assert isinstance (completion_ids , list )
369+ assert len (completion_ids ) == len (prompts )
342370
343371 client .close_communicator ()
344372
0 commit comments