4747 WisdomLogAwareMixin ,
4848)
4949
50+ VALIDATE_PROMPT = "---\n - hosts: all\n tasks:\n - name: install ssh\n "
51+
5052
5153@override_settings (DEPLOYMENT_MODE = "saas" )
5254@override_settings (WCA_SECRET_BACKEND_TYPE = "aws_sm" )
@@ -183,6 +185,9 @@ def _test_set_model_id(self, has_seat):
183185 self .user .organization = Organization .objects .get_or_create (id = 123 )[0 ]
184186 self .user .rh_user_has_seat = has_seat
185187 mock_secret_manager = apps .get_app_config ("ai" ).get_wca_secret_manager ()
188+ mock_wca_client : ModelPipelineCompletions = apps .get_app_config ("ai" ).get_model_pipeline (
189+ ModelPipelineCompletions
190+ )
186191 self .client .force_authenticate (user = self .user )
187192
188193 # ModelId should initially not exist
@@ -192,7 +197,15 @@ def _test_set_model_id(self, has_seat):
192197 mock_secret_manager .get_secret .assert_called_with (123 , Suffixes .MODEL_ID )
193198
194199 # Set ModelId
195- mock_secret_manager .get_secret .return_value = {"SecretString" : "someAPIKey" }
200+ api_key_value = "someAPIKey"
201+ model_id_value = "secret_model_id"
202+ mock_secret_manager .get_secret .return_value = {"SecretString" : api_key_value }
203+
204+ expected_headers = {"Authorization" : f"Bearer { api_key_value } " , "X-Test-Header-Set" : "true" }
205+ mock_wca_client .get_request_headers .return_value = expected_headers
206+ mock_wca_client .infer_from_parameters .reset_mock (side_effect = True )
207+ mock_wca_client .infer_from_parameters .side_effect = None
208+
196209 with self .assertLogs (logger = "ansible_ai_connect.users.signals" , level = "DEBUG" ) as signals :
197210 with self .assertLogs (logger = "root" , level = "DEBUG" ) as log :
198211 r = self .client .post (
@@ -202,8 +215,19 @@ def _test_set_model_id(self, has_seat):
202215 )
203216
204217 self .assertEqual (r .status_code , HTTPStatus .NO_CONTENT )
218+
219+ mock_wca_client .get_request_headers .assert_called_once_with (
220+ api_key = api_key_value , identifier = None , lightspeed_user_uuid = None
221+ )
222+ mock_wca_client .infer_from_parameters .assert_called_once_with (
223+ model_id_value ,
224+ "" ,
225+ VALIDATE_PROMPT ,
226+ user = None ,
227+ headers = expected_headers ,
228+ )
205229 mock_secret_manager .save_secret .assert_called_with (
206- 123 , Suffixes .MODEL_ID , "secret_model_id"
230+ 123 , Suffixes .MODEL_ID , model_id_value
207231 )
208232 self .assert_segment_log (log , "modelIdSet" , None )
209233
@@ -432,18 +456,45 @@ def _test_validate_ok(self, has_seat):
432456 self .user .organization = Organization .objects .get_or_create (id = 123 )[0 ]
433457 self .user .rh_user_has_seat = has_seat
434458 mock_secret_manager = apps .get_app_config ("ai" ).get_wca_secret_manager ()
459+ mock_wca_client : ModelPipelineCompletions = apps .get_app_config ("ai" ).get_model_pipeline (
460+ ModelPipelineCompletions
461+ )
435462 self .client .force_authenticate (user = self .user )
436463
437- def mock_get_secret_model_id (* args , ** kwargs ):
464+ api_key_value = "some_api_key_for_validate"
465+ model_id_value = "model_id_for_validate"
466+
467+ def mock_get_secret_side_effect (* args , ** kwargs ):
438468 if args [1 ] == Suffixes .API_KEY :
439- return {"SecretString" : "some_api_key" }
440- return {"SecretString" : "model_id" }
469+ return {"SecretString" : api_key_value }
470+ if args [1 ] == Suffixes .MODEL_ID :
471+ return {"SecretString" : model_id_value }
472+ return None
441473
442- mock_secret_manager .get_secret .side_effect = mock_get_secret_model_id
474+ mock_secret_manager .get_secret .side_effect = mock_get_secret_side_effect
475+
476+ expected_headers = {
477+ "Authorization" : f"Bearer { api_key_value } " ,
478+ "X-Test-Header-Validate" : "true" ,
479+ }
480+ mock_wca_client .get_request_headers .return_value = expected_headers
481+ mock_wca_client .infer_from_parameters .reset_mock (side_effect = True )
482+ mock_wca_client .infer_from_parameters .side_effect = None
443483
444484 with self .assertLogs (logger = "root" , level = "DEBUG" ) as log :
445485 r = self .client .get (self .api_version_reverse ("wca_model_id_validator" ))
446486 self .assertEqual (r .status_code , HTTPStatus .OK )
487+
488+ mock_wca_client .get_request_headers .assert_called_once_with (
489+ api_key = api_key_value , identifier = None , lightspeed_user_uuid = None
490+ )
491+ mock_wca_client .infer_from_parameters .assert_called_once_with (
492+ model_id_value ,
493+ "" ,
494+ VALIDATE_PROMPT ,
495+ user = None ,
496+ headers = expected_headers ,
497+ )
447498 self .assert_segment_log (log , "modelIdValidate" , None )
448499
449500 @override_settings (SEGMENT_WRITE_KEY = "DUMMY_KEY_VALUE" )
0 commit comments