diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index ee9bf622..474edf9d 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -87,6 +87,9 @@ def _validate_claims(self, payload, options, audience=None, issuer=None, if isinstance(leeway, timedelta): leeway = timedelta_total_seconds(leeway) + if not isinstance(audience, (string_types, type(None), list)): + raise TypeError('audience must be a string, list of strings, or None') + self._validate_required_claims(payload, options) now = timegm(datetime.utcnow().utctimetuple()) @@ -162,9 +165,6 @@ def _validate_aud(self, payload, audience): if isinstance(audience, string_types): audience = [audience] - if not isinstance(audience, list): - raise InvalidAudienceError('Invalid audience format') - for aud in audience: if aud in audience_claims: return diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index f1151c91..adba0914 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -92,7 +92,7 @@ def test_decode_with_invalid_audience_param_throws_exception(self, jwt): jwt.decode(example_jwt, secret, audience=1) exception = context.value - assert str(exception) == 'audience must be a string or None' + assert str(exception) == 'audience must be a string, list of strings, or None' def test_decode_with_nonlist_aud_claim_throws_exception(self, jwt): secret = 'secret' @@ -289,7 +289,6 @@ def test_check_audience_list_when_valid(self, jwt): token = jwt.encode(payload, 'secret') jwt.decode(token, 'secret', audience=['urn:you', 'urn:me']) - def test_raise_exception_invalid_audience_list(self, jwt): payload = { 'some': 'payload',