diff --git a/AUTHORS b/AUTHORS index 02fbc3bb..2511b2ee 100644 --- a/AUTHORS +++ b/AUTHORS @@ -23,3 +23,5 @@ Patches and Suggestions - Wouter Bolsterlee - Michael Davis + + - Vinod Gupta diff --git a/CHANGELOG.md b/CHANGELOG.md index 696a6a96..5793d708 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,10 @@ This project adheres to [Semantic Versioning](http://semver.org/). - Dropped support for python 2.6 and 3.3 [#297][297] +- Audience parameter now supports iterables [#205][205] + ### Fixed + ### Added [v1.5.3][1.5.3] diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 223b22be..ad3ff6ae 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -1,7 +1,6 @@ import binascii import json import warnings - from collections import Mapping from .algorithms import ( diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 5ddc8a30..edef7701 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -1,8 +1,7 @@ import json import warnings - from calendar import timegm -from collections import Mapping +from collections import Iterable, Mapping from datetime import datetime, timedelta from .api_jws import PyJWS @@ -103,8 +102,8 @@ def _validate_claims(self, payload, options, audience=None, issuer=None, if isinstance(leeway, timedelta): leeway = leeway.total_seconds() - if not isinstance(audience, (string_types, type(None))): - raise TypeError('audience must be a string or None') + if not isinstance(audience, (string_types, type(None), Iterable)): + raise TypeError('audience must be a string, iterable, or None') self._validate_required_claims(payload, options) @@ -177,7 +176,11 @@ def _validate_aud(self, payload, audience): raise InvalidAudienceError('Invalid claim format in token') if any(not isinstance(c, string_types) for c in audience_claims): raise InvalidAudienceError('Invalid claim format in token') - if audience not in audience_claims: + + if isinstance(audience, string_types): + audience = [audience] + + if not any(aud in audience_claims for aud in audience): raise InvalidAudienceError('Invalid audience') def _validate_iss(self, payload, issuer): diff --git a/jwt/contrib/algorithms/pycrypto.py b/jwt/contrib/algorithms/pycrypto.py index e6afaa59..e49cdbfe 100644 --- a/jwt/contrib/algorithms/pycrypto.py +++ b/jwt/contrib/algorithms/pycrypto.py @@ -1,7 +1,6 @@ import Crypto.Hash.SHA256 import Crypto.Hash.SHA384 import Crypto.Hash.SHA512 - from Crypto.PublicKey import RSA from Crypto.Signature import PKCS1_v1_5 diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index 4e440bd7..60671a29 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -1,6 +1,5 @@ import json - from decimal import Decimal from jwt.algorithms import Algorithm diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 8ce3f2cc..58b47f2c 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -1,7 +1,6 @@ import json import time - from calendar import timegm from datetime import datetime, timedelta from decimal import Decimal @@ -92,7 +91,7 @@ def test_decode_with_invalid_audience_param_throws_exception(self, jwt): jwt.decode(example_jwt, secret, audience=1) exception = context.value - assert str(exception) == 'audience must be a string or None' + assert str(exception) == 'audience must be a string, iterable, or None' def test_decode_with_nonlist_aud_claim_throws_exception(self, jwt): secret = 'secret' @@ -281,6 +280,23 @@ def test_check_audience_when_valid(self, jwt): token = jwt.encode(payload, 'secret') jwt.decode(token, 'secret', audience='urn:me') + def test_check_audience_list_when_valid(self, jwt): + payload = { + 'some': 'payload', + 'aud': 'urn:me' + } + token = jwt.encode(payload, 'secret') + jwt.decode(token, 'secret', audience=['urn:you', 'urn:me']) + + def test_raise_exception_invalid_audience_list(self, jwt): + payload = { + 'some': 'payload', + 'aud': 'urn:me' + } + token = jwt.encode(payload, 'secret') + with pytest.raises(InvalidAudienceError): + jwt.decode(token, 'secret', audience=['urn:you', 'urn:him']) + def test_check_audience_in_array_when_valid(self, jwt): payload = { 'some': 'payload', diff --git a/tests/utils.py b/tests/utils.py index 79c77b0c..be189f2c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,5 @@ import os import struct - from calendar import timegm from datetime import datetime