@@ -94,6 +94,20 @@ def test_select_model_id_invalid_model(mocker):
9494 )
9595
9696
97+ def test_no_available_models (mocker ):
98+ """Test the select_model_id function with an invalid model."""
99+ mock_client = mocker .Mock ()
100+ # empty list of models
101+ mock_client .models .list .return_value = []
102+
103+ query_request = QueryRequest (query = "What is OpenStack?" , model = None , provider = None )
104+
105+ with pytest .raises (Exception ) as exc_info :
106+ select_model_id (mock_client , query_request )
107+
108+ assert "No LLM model found in available models" in str (exc_info .value )
109+
110+
97111def test_validate_attachments_metadata ():
98112 """Test the validate_attachments_metadata function."""
99113 attachments = [
@@ -151,7 +165,7 @@ def test_validate_attachments_metadata_invalid_content_type():
151165 )
152166
153167
154- def test_retrieve_response (mocker ):
168+ def test_retrieve_response_no_available_shields (mocker ):
155169 """Test the retrieve_response function."""
156170 mock_agent = mocker .Mock ()
157171 mock_agent .create_turn .return_value .output_message .content = "LLM answer"
@@ -172,3 +186,147 @@ def test_retrieve_response(mocker):
172186 documents = [],
173187 stream = False ,
174188 )
189+
190+
191+ def test_retrieve_response_one_available_shield (mocker ):
192+ """Test the retrieve_response function."""
193+
194+ class MockShield :
195+ def __init__ (self , identifier ):
196+ self .identifier = identifier
197+
198+ def identifier (self ):
199+ return self .identifier
200+
201+ mock_agent = mocker .Mock ()
202+ mock_agent .create_turn .return_value .output_message .content = "LLM answer"
203+ mock_client = mocker .Mock ()
204+ mock_client .shields .list .return_value = [MockShield ("shield1" )]
205+
206+ mocker .patch ("app.endpoints.query.Agent" , return_value = mock_agent )
207+
208+ query_request = QueryRequest (query = "What is OpenStack?" )
209+ model_id = "fake_model_id"
210+
211+ response = retrieve_response (mock_client , model_id , query_request )
212+
213+ assert response == "LLM answer"
214+ mock_agent .create_turn .assert_called_once_with (
215+ messages = [UserMessage (content = "What is OpenStack?" , role = "user" , context = None )],
216+ session_id = mocker .ANY ,
217+ documents = [],
218+ stream = False ,
219+ )
220+
221+
222+ def test_retrieve_response_two_available_shields (mocker ):
223+ """Test the retrieve_response function."""
224+
225+ class MockShield :
226+ def __init__ (self , identifier ):
227+ self .identifier = identifier
228+
229+ def identifier (self ):
230+ return self .identifier
231+
232+ mock_agent = mocker .Mock ()
233+ mock_agent .create_turn .return_value .output_message .content = "LLM answer"
234+ mock_client = mocker .Mock ()
235+ mock_client .shields .list .return_value = [
236+ MockShield ("shield1" ),
237+ MockShield ("shield2" ),
238+ ]
239+
240+ mocker .patch ("app.endpoints.query.Agent" , return_value = mock_agent )
241+
242+ query_request = QueryRequest (query = "What is OpenStack?" )
243+ model_id = "fake_model_id"
244+
245+ response = retrieve_response (mock_client , model_id , query_request )
246+
247+ assert response == "LLM answer"
248+ mock_agent .create_turn .assert_called_once_with (
249+ messages = [UserMessage (content = "What is OpenStack?" , role = "user" , context = None )],
250+ session_id = mocker .ANY ,
251+ documents = [],
252+ stream = False ,
253+ )
254+
255+
256+ def test_retrieve_response_with_one_attachment (mocker ):
257+ """Test the retrieve_response function."""
258+ mock_agent = mocker .Mock ()
259+ mock_agent .create_turn .return_value .output_message .content = "LLM answer"
260+ mock_client = mocker .Mock ()
261+ mock_client .shields .list .return_value = []
262+
263+ attachments = [
264+ Attachment (
265+ attachment_type = "log" ,
266+ content_type = "text/plain" ,
267+ content = "this is attachment" ,
268+ ),
269+ ]
270+ mocker .patch ("app.endpoints.query.Agent" , return_value = mock_agent )
271+
272+ query_request = QueryRequest (query = "What is OpenStack?" , attachments = attachments )
273+ model_id = "fake_model_id"
274+
275+ response = retrieve_response (mock_client , model_id , query_request )
276+
277+ assert response == "LLM answer"
278+ mock_agent .create_turn .assert_called_once_with (
279+ messages = [UserMessage (content = "What is OpenStack?" , role = "user" , context = None )],
280+ session_id = mocker .ANY ,
281+ stream = False ,
282+ documents = [
283+ {
284+ "content" : "this is attachment" ,
285+ "mime_type" : "text/plain" ,
286+ },
287+ ],
288+ )
289+
290+
291+ def test_retrieve_response_with_two_attachments (mocker ):
292+ """Test the retrieve_response function."""
293+ mock_agent = mocker .Mock ()
294+ mock_agent .create_turn .return_value .output_message .content = "LLM answer"
295+ mock_client = mocker .Mock ()
296+ mock_client .shields .list .return_value = []
297+
298+ attachments = [
299+ Attachment (
300+ attachment_type = "log" ,
301+ content_type = "text/plain" ,
302+ content = "this is attachment" ,
303+ ),
304+ Attachment (
305+ attachment_type = "configuration" ,
306+ content_type = "application/yaml" ,
307+ content = "kind: Pod\n metadata:\n name: private-reg" ,
308+ ),
309+ ]
310+ mocker .patch ("app.endpoints.query.Agent" , return_value = mock_agent )
311+
312+ query_request = QueryRequest (query = "What is OpenStack?" , attachments = attachments )
313+ model_id = "fake_model_id"
314+
315+ response = retrieve_response (mock_client , model_id , query_request )
316+
317+ assert response == "LLM answer"
318+ mock_agent .create_turn .assert_called_once_with (
319+ messages = [UserMessage (content = "What is OpenStack?" , role = "user" , context = None )],
320+ session_id = mocker .ANY ,
321+ stream = False ,
322+ documents = [
323+ {
324+ "content" : "this is attachment" ,
325+ "mime_type" : "text/plain" ,
326+ },
327+ {
328+ "content" : "kind: Pod\n " " metadata:\n " " name: private-reg" ,
329+ "mime_type" : "application/yaml" ,
330+ },
331+ ],
332+ )
0 commit comments