99
1010url = os .environ .get ("GRADIO_URL" , "http://localhost:7860" )
1111client = Client (url )
12-
13- class TestSuite (unittest .TestCase ):
14- # General tests
12+ latest_message = "Why don't humans drink horse milk?"
13+ history = [
14+ {
15+ "role" : "user" ,
16+ "metadata" : None ,
17+ "content" : "Hi!" ,
18+ "options" : None ,
19+ },
20+ {
21+ "role" : "assistant" ,
22+ "metadata" : None ,
23+ "content" : "Hello! How can I help you?" ,
24+ "options" : None ,
25+ },
26+ ]
27+
28+ class TestAPI (unittest .TestCase ):
1529 def test_gradio_api (self ):
1630 result = client .predict ("Hi" , api_name = "/chat" )
1731 self .assertGreater (len (result ), 0 )
1832
19- # build_chat_context function tests
33+ class TestBuildChatContext ( unittest . TestCase ):
2034 @patch ("app.settings" )
2135 @patch ("app.INCLUDE_SYSTEM_PROMPT" , True )
2236 def test_chat_context_system_prompt (self , mock_settings ):
2337 mock_settings .model_instruction = "You are a helpful assistant."
24- latest_message = "What is a mammal?"
25- history = [
26- {'role' : 'user' , 'metadata' : None , 'content' : 'Hi!' , 'options' : None },
27- {"role" : "assistant" , 'metadata' : None , "content" : "Hello! How can I help you?" , 'options' : None },
28- ]
2938
3039 context = build_chat_context (latest_message , history )
3140
3241 self .assertEqual (len (context ), 4 )
3342 self .assertIsInstance (context [0 ], SystemMessage )
3443 self .assertEqual (context [0 ].content , "You are a helpful assistant." )
3544 self .assertIsInstance (context [1 ], HumanMessage )
36- self .assertEqual (context [1 ].content , "Hi!" )
45+ self .assertEqual (context [1 ].content , history [ 0 ][ "content" ] )
3746 self .assertIsInstance (context [2 ], AIMessage )
38- self .assertEqual (context [2 ].content , "Hello! How can I help you?" )
47+ self .assertEqual (context [2 ].content , history [ 1 ][ "content" ] )
3948 self .assertIsInstance (context [3 ], HumanMessage )
4049 self .assertEqual (context [3 ].content , latest_message )
4150
4251 @patch ("app.settings" )
4352 @patch ("app.INCLUDE_SYSTEM_PROMPT" , False )
4453 def test_chat_context_human_prompt (self , mock_settings ):
4554 mock_settings .model_instruction = "You are a very helpful assistant."
46- latest_message = "What is a fish?"
47- history = [
48- {"role" : "user" , 'metadata' : None , "content" : "Hi there!" , 'options' : None },
49- {"role" : "assistant" , 'metadata' : None , "content" : "Hi! How can I help you?" , 'options' : None },
50- ]
5155
5256 context = build_chat_context (latest_message , history )
5357
5458 self .assertEqual (len (context ), 3 )
5559 self .assertIsInstance (context [0 ], HumanMessage )
56- self .assertEqual (context [0 ].content , "You are a very helpful assistant.\n \n Hi there !" )
60+ self .assertEqual (context [0 ].content , "You are a very helpful assistant.\n \n Hi!" )
5761 self .assertIsInstance (context [1 ], AIMessage )
58- self .assertEqual (context [1 ].content , "Hi! How can I help you?" )
62+ self .assertEqual (context [1 ].content , history [ 1 ][ "content" ] )
5963 self .assertIsInstance (context [2 ], HumanMessage )
6064 self .assertEqual (context [2 ].content , latest_message )
6165
62- # inference function tests
66+ class TestInference ( unittest . TestCase ):
6367 @patch ("app.settings" )
6468 @patch ("app.llm" )
6569 @patch ("app.log" )
6670 def test_inference_success (self , mock_logger , mock_llm , mock_settings ):
6771 mock_llm .stream .return_value = [MagicMock (content = "response_chunk" )]
6872
6973 mock_settings .model_instruction = "You are a very helpful assistant."
70- latest_message = "Why don't we drink horse milk?"
71- history = [
72- {"role" : "user" , 'metadata' : None , "content" : "Hi there!" , 'options' : None },
73- {"role" : "assistant" , 'metadata' : None , "content" : "Hi! How can I help you?" , 'options' : None },
74- ]
7574
7675 responses = list (inference (latest_message , history ))
7776
@@ -88,8 +87,6 @@ def test_inference_thinking_tags(self, mock_build_chat_context, mock_llm):
8887 MagicMock (content = "</think>" ),
8988 MagicMock (content = "final response" ),
9089 ]
91- latest_message = "Hello"
92- history = []
9390
9491 responses = list (inference (latest_message , history ))
9592
@@ -98,7 +95,8 @@ def test_inference_thinking_tags(self, mock_build_chat_context, mock_llm):
9895 @patch ("app.llm" )
9996 @patch ("app.INCLUDE_SYSTEM_PROMPT" , True )
10097 @patch ("app.build_chat_context" )
101- def test_inference_PossibleSystemPromptException (self , mock_build_chat_context , mock_llm ):
98+ @patch ("app.log" )
99+ def test_inference_PossibleSystemPromptException (self , mock_logger , mock_build_chat_context , mock_llm ):
102100 mock_build_chat_context .return_value = ["mock_context" ]
103101 mock_response = Mock ()
104102 mock_response .json .return_value = {"message" : "Bad request" }
@@ -109,16 +107,15 @@ def test_inference_PossibleSystemPromptException(self, mock_build_chat_context,
109107 body = None
110108 )
111109
112- latest_message = "Hello"
113- history = []
114-
115110 with self .assertRaises (PossibleSystemPromptException ):
116111 list (inference (latest_message , history ))
112+ mock_logger .error .assert_called_once_with ("Received BadRequestError from backend API: %s" , mock_llm .stream .side_effect )
117113
118114 @patch ("app.llm" )
119115 @patch ("app.INCLUDE_SYSTEM_PROMPT" , False )
120116 @patch ("app.build_chat_context" )
121- def test_inference_general_error (self , mock_build_chat_context , mock_llm ):
117+ @patch ("app.log" )
118+ def test_inference_general_error (self , mock_logger , mock_build_chat_context , mock_llm ):
122119 mock_build_chat_context .return_value = ["mock_context" ]
123120 mock_response = Mock ()
124121 mock_response .json .return_value = {"message" : "Bad request" }
@@ -129,13 +126,12 @@ def test_inference_general_error(self, mock_build_chat_context, mock_llm):
129126 body = None
130127 )
131128
132- latest_message = "Hello"
133- history = []
134129 exception_message = "\' API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: Bad request\' "
135130
136131 with self .assertRaises (gr .Error ) as gradio_error :
137132 list (inference (latest_message , history ))
138133 self .assertEqual (str (gradio_error .exception ), exception_message )
134+ mock_logger .error .assert_called_once_with ("Received BadRequestError from backend API: %s" , mock_llm .stream .side_effect )
139135
140136 @patch ("app.llm" )
141137 @patch ("app.build_chat_context" )
@@ -152,9 +148,6 @@ def test_inference_APIConnectionError(self, mock_gr, mock_logger, mock_build_cha
152148 request = mock_request ,
153149 )
154150
155- latest_message = "Hello"
156- history = []
157-
158151 list (inference (latest_message , history ))
159152 mock_logger .info .assert_any_call ("Backend API not yet ready" )
160153 mock_gr .Info .assert_any_call ("Backend not ready - model may still be initialising - please try again later." )
@@ -174,9 +167,6 @@ def test_inference_APIConnectionError_initialised(self, mock_gr, mock_logger, mo
174167 request = mock_request ,
175168 )
176169
177- latest_message = "Hello"
178- history = []
179-
180170 list (inference (latest_message , history ))
181171 mock_logger .error .assert_called_once_with ("Failed to connect to backend API: %s" , mock_llm .stream .side_effect )
182172 mock_gr .Warning .assert_any_call ("Failed to connect to backend API." )
@@ -195,11 +185,8 @@ def test_inference_InternalServerError(self, mock_gr, mock_build_chat_context, m
195185 body = None
196186 )
197187
198- latest_message = "Hello"
199- history = []
200-
201188 list (inference (latest_message , history ))
202189 mock_gr .Warning .assert_any_call ("Internal server error encountered in backend API - see API logs for details." )
203190
204191if __name__ == "__main__" :
205- unittest .main ()
192+ unittest .main (verbosity = 2 )
0 commit comments