Skip to content

Commit 0db5143

Browse files
authored
AAP-18027: WCAClient.get_token: Distinguish between API Key related failures and others (#697)
Co-authored-by: Michael Anstis <manstis@redhat.com>
1 parent 782431d commit 0db5143

File tree

7 files changed

+89
-25
lines changed

7 files changed

+89
-25
lines changed

ansible_wisdom/ai/api/model_client/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ class WcaTokenFailure(WcaException):
4747
"""An attempt to retrieve a WCA Token failed."""
4848

4949

50+
@dataclass
51+
class WcaTokenFailureApiKeyError(WcaException):
52+
"""An attempt to retrieve a WCA Token failed due to a problem with the provided API Key."""
53+
54+
5055
@dataclass
5156
class WcaCloudflareRejection(WcaException):
5257
"""Cloudflare rejected the request."""

ansible_wisdom/ai/api/model_client/wca_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
ContentMatchResponseChecks,
1010
InferenceContext,
1111
InferenceResponseChecks,
12+
TokenContext,
13+
TokenResponseChecks,
1214
)
1315
from django.apps import apps
1416
from django.conf import settings
@@ -205,6 +207,8 @@ def post_request():
205207

206208
try:
207209
response = post_request()
210+
context = TokenContext(response)
211+
TokenResponseChecks().run_checks(context)
208212
response.raise_for_status()
209213

210214
except HTTPError as e:

ansible_wisdom/ai/api/model_client/wca_utils.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
from abc import abstractmethod
22
from typing import Generic, TypeVar
33

4-
from .exceptions import WcaCloudflareRejection, WcaEmptyResponse, WcaInvalidModelId
4+
from .exceptions import (
5+
WcaCloudflareRejection,
6+
WcaEmptyResponse,
7+
WcaInvalidModelId,
8+
WcaTokenFailureApiKeyError,
9+
)
510

611
T = TypeVar('T')
712

@@ -23,6 +28,34 @@ def run_checks(self, context: T):
2328
check.check(context)
2429

2530

31+
class TokenContext:
32+
def __init__(self, result):
33+
self.result = result
34+
35+
36+
class TokenResponseChecks(Checks[TokenContext]):
37+
class ResponseStatusCode400Missing(Check[TokenContext]):
38+
def check(self, context: TokenContext):
39+
if context.result.status_code == 400:
40+
if "Property missing or empty" in context.result.json()["errorMessage"]:
41+
raise WcaTokenFailureApiKeyError()
42+
43+
class ResponseStatusCode400NotFound(Check[TokenContext]):
44+
def check(self, context: TokenContext):
45+
if context.result.status_code == 400:
46+
if "Provided API key could not be found" in context.result.json()["errorMessage"]:
47+
raise WcaTokenFailureApiKeyError()
48+
49+
def __init__(self):
50+
super().__init__(
51+
[
52+
# The ordering of these checks is important!
53+
TokenResponseChecks.ResponseStatusCode400Missing(),
54+
TokenResponseChecks.ResponseStatusCode400NotFound(),
55+
]
56+
)
57+
58+
2659
class InferenceContext:
2760
def __init__(self, model_id, result, is_multi_task_prompt):
2861
self.model_id = model_id

ansible_wisdom/ai/api/wca/api_key_views.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from ai.api.aws.exceptions import WcaSecretManagerError
55
from ai.api.aws.wca_secret_manager import Suffixes
6-
from ai.api.model_client.exceptions import WcaTokenFailure
6+
from ai.api.model_client.exceptions import WcaTokenFailureApiKeyError
77
from ai.api.permissions import (
88
AcceptedTermsPermission,
99
IsOrganisationAdministrator,
@@ -140,18 +140,18 @@ def post(self, request, *args, **kwargs):
140140
logger.info(e, exc_info=True)
141141
return Response(status=HTTP_400_BAD_REQUEST)
142142

143-
except WcaTokenFailure as e:
143+
except WcaTokenFailureApiKeyError as e:
144144
exception = e
145145
logger.info(
146146
f"An error occurred trying to retrieve a WCA Token for Organisation '{org_id}'.",
147147
exc_info=True,
148148
)
149149
return Response(status=HTTP_400_BAD_REQUEST)
150150

151-
except WcaSecretManagerError as e:
151+
except Exception as e:
152152
exception = e
153153
logger.exception(e)
154-
raise ServiceUnavailable
154+
raise ServiceUnavailable(cause=e)
155155

156156
finally:
157157
duration = round((time.time() - start_time) * 1000, 2)
@@ -199,23 +199,25 @@ def get(self, request, *args, **kwargs):
199199
model_mesh_client = apps.get_app_config("ai").wca_client
200200
secret_manager = apps.get_app_config("ai").get_wca_secret_manager()
201201
api_key = secret_manager.get_secret(org_id, Suffixes.API_KEY)
202-
token = model_mesh_client.get_token(api_key['SecretString'])
202+
if api_key is None:
203+
return Response(status=HTTP_400_BAD_REQUEST)
204+
token = model_mesh_client.get_token(api_key)
203205
if token is None:
204206
return Response(status=HTTP_400_BAD_REQUEST)
205207

206-
except WcaSecretManagerError as e:
207-
exception = e
208-
logger.exception(e)
209-
raise ServiceUnavailable
210-
211-
except WcaTokenFailure as e:
208+
except WcaTokenFailureApiKeyError as e:
212209
exception = e
213210
logger.info(
214211
f"An error occurred trying to retrieve a WCA Token for Organisation '{org_id}'.",
215212
exc_info=True,
216213
)
217214
return Response(status=HTTP_400_BAD_REQUEST)
218215

216+
except Exception as e:
217+
exception = e
218+
logger.exception(e)
219+
raise ServiceUnavailable(cause=e)
220+
219221
finally:
220222
duration = round((time.time() - start_time) * 1000, 2)
221223
event = {

ansible_wisdom/ai/api/wca/model_id_views.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,6 @@ def do_validated_operation(request, api_key_provider, model_id_provider, on_succ
240240
logger.info(e, exc_info=True)
241241
raise WcaBadRequestException(cause=e)
242242

243-
except WcaSecretManagerError as e:
244-
exception = e
245-
logger.exception(e)
246-
return Response(status=HTTP_500_INTERNAL_SERVER_ERROR)
247-
248243
except Exception as e:
249244
exception = e
250245
logger.exception(e)

ansible_wisdom/ai/api/wca/tests/test_api_key_views.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ai.api.aws.exceptions import WcaSecretManagerError
66
from ai.api.aws.wca_secret_manager import Suffixes, WcaSecretManager
7-
from ai.api.model_client.exceptions import WcaTokenFailure
7+
from ai.api.model_client.exceptions import WcaTokenFailure, WcaTokenFailureApiKeyError
88
from ai.api.model_client.wca_client import WCAClient
99
from ai.api.permissions import (
1010
AcceptedTermsPermission,
@@ -194,7 +194,9 @@ def test_set_key_with_invalid_value(self, *args):
194194
self.mock_secret_manager.get_secret.assert_called_with('123', Suffixes.API_KEY)
195195

196196
# Set Key
197-
self.mock_wca_client.get_token.side_effect = WcaTokenFailure('Something went wrong')
197+
self.mock_wca_client.get_token.side_effect = WcaTokenFailureApiKeyError(
198+
'Something went wrong'
199+
)
198200
with self.assertLogs(logger='root', level='DEBUG') as log:
199201
r = self.client.post(
200202
reverse('wca_api_key'),
@@ -203,7 +205,7 @@ def test_set_key_with_invalid_value(self, *args):
203205
)
204206
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
205207
self.mock_secret_manager.save_secret.assert_not_called()
206-
_assert_segment_log(self, log, "modelApiKeySet", "WcaTokenFailure")
208+
_assert_segment_log(self, log, "modelApiKeySet", "WcaTokenFailureApiKeyError")
207209

208210
@override_settings(SEGMENT_WRITE_KEY='DUMMY_KEY_VALUE')
209211
def test_set_key_throws_secret_manager_exception(self, *args):
@@ -222,18 +224,17 @@ def test_set_key_throws_secret_manager_exception(self, *args):
222224
_assert_segment_log(self, log, "modelApiKeySet", "WcaSecretManagerError")
223225

224226
@override_settings(SEGMENT_WRITE_KEY='DUMMY_KEY_VALUE')
225-
def test_set_key_throws_wca_client_exception(self, *args):
227+
def test_set_key_throws_http_exception(self, *args):
226228
self.user.organization_id = '123'
227229
self.client.force_authenticate(user=self.user)
228230
self.mock_wca_client.get_token.side_effect = WcaTokenFailure()
229-
230231
with self.assertLogs(logger='root', level='DEBUG') as log:
231232
r = self.client.post(
232233
reverse('wca_api_key'),
233234
data='{ "key": "a-new-key" }',
234235
content_type='application/json',
235236
)
236-
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
237+
self.assertEqual(r.status_code, HTTPStatus.SERVICE_UNAVAILABLE)
237238
_assert_segment_log(self, log, "modelApiKeySet", "WcaTokenFailure")
238239

239240
@override_settings(SEGMENT_WRITE_KEY='DUMMY_KEY_VALUE')
@@ -314,13 +315,37 @@ def _test_validate_key_with_valid_value(self, has_seat):
314315
self.assertEqual(r.status_code, HTTPStatus.OK)
315316
_assert_segment_log(self, log, "modelApiKeyValidate", None)
316317

318+
@override_settings(SEGMENT_WRITE_KEY='DUMMY_KEY_VALUE')
319+
def test_validate_key_with_missing_value(self, *args):
320+
self.user.organization_id = '123'
321+
self.client.force_authenticate(user=self.user)
322+
self.mock_secret_manager.get_secret.return_value = None
323+
324+
with self.assertLogs(logger='root', level='DEBUG') as log:
325+
r = self.client.get(reverse('wca_api_key_validator'))
326+
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
327+
_assert_segment_log(self, log, "modelApiKeyValidate", None)
328+
317329
@override_settings(SEGMENT_WRITE_KEY='DUMMY_KEY_VALUE')
318330
def test_validate_key_with_invalid_value(self, *args):
319331
self.user.organization_id = '123'
320332
self.client.force_authenticate(user=self.user)
321-
self.mock_wca_client.get_token.side_effect = WcaTokenFailure('Something went wrong')
333+
self.mock_wca_client.get_token.side_effect = WcaTokenFailureApiKeyError(
334+
'Something went wrong'
335+
)
322336

323337
with self.assertLogs(logger='root', level='DEBUG') as log:
324338
r = self.client.get(reverse('wca_api_key_validator'))
325339
self.assertEqual(r.status_code, HTTPStatus.BAD_REQUEST)
340+
_assert_segment_log(self, log, "modelApiKeyValidate", "WcaTokenFailureApiKeyError")
341+
342+
@override_settings(SEGMENT_WRITE_KEY='DUMMY_KEY_VALUE')
343+
def test_validate_key_throws_http_exception(self, *args):
344+
self.user.organization_id = '123'
345+
self.client.force_authenticate(user=self.user)
346+
self.mock_wca_client.get_token.side_effect = WcaTokenFailure('Something went wrong')
347+
348+
with self.assertLogs(logger='root', level='DEBUG') as log:
349+
r = self.client.get(reverse('wca_api_key_validator'))
350+
self.assertEqual(r.status_code, HTTPStatus.SERVICE_UNAVAILABLE)
326351
_assert_segment_log(self, log, "modelApiKeyValidate", "WcaTokenFailure")

ansible_wisdom/ai/api/wca/tests/test_model_id_views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def test_set_model_id_throws_secret_manager_exception(self, *args):
231231
data='{ "model_id": "secret_model_id" }',
232232
content_type='application/json',
233233
)
234-
self.assertEqual(r.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
234+
self.assertEqual(r.status_code, HTTPStatus.SERVICE_UNAVAILABLE)
235235
self.assertInLog('ai.api.aws.exceptions.WcaSecretManagerError', log)
236236
_assert_segment_log(self, log, "modelIdSet", "WcaSecretManagerError")
237237

0 commit comments

Comments
 (0)