Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove one-time token behavior of JWT Credentials #117

Merged
merged 4 commits into from
Feb 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 19 additions & 76 deletions google/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@
import datetime
import json

from six.moves import urllib

from google.auth import _helpers
from google.auth import _service_account_info
from google.auth import credentials
Expand Down Expand Up @@ -246,11 +244,7 @@ class Credentials(credentials.Signing,
"""Credentials that use a JWT as the bearer token.

These credentials require an "audience" claim. This claim identifies the
intended recipient of the bearer token. You can set the audience when
you construct these credentials, however, these credentials can also set
the audience claim automatically if not specified. In this case, whenever
a request is made the credentials will automatically generate a one-time
JWT with the request URI as the audience.
intended recipient of the bearer token.

The constructor arguments determine the claims for the JWT that is
sent with requests. Usually, you'll construct these credentials with
Expand All @@ -260,13 +254,15 @@ class Credentials(credentials.Signing,
JSON file::

credentials = jwt.Credentials.from_service_account_file(
'service-account.json')
'service-account.json',
audience='https://speech.googleapis.com')

This comment was marked as spam.

This comment was marked as spam.


If you already have the service account file loaded and parsed::

service_account_info = json.load(open('service_account.json'))
credentials = jwt.Credentials.from_service_account_info(
service_account_info)
service_account_info,
audience='https://speech.googleapis.com')

Both helper methods pass on arguments to the constructor, so you can
specify the JWT claims::
Expand All @@ -280,7 +276,10 @@ class Credentials(credentials.Signing,
:class:`~google.auth.crypt.Signer` instance::

credentials = jwt.Credentials(
signer, issuer='your-issuer', subject='your-subject')
signer,
issuer='your-issuer',
subject='your-subject',
audience=''https://speech.googleapis.com'')

The claims are considered immutable. If you want to modify the claims,
you can easily create another instance using :meth:`with_claims`::
Expand All @@ -289,7 +288,7 @@ class Credentials(credentials.Signing,
audience='https://vision.googleapis.com')
"""

def __init__(self, signer, issuer=None, subject=None, audience=None,
def __init__(self, signer, issuer, subject, audience,

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

additional_claims=None,
token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS):
"""
Expand All @@ -298,8 +297,7 @@ def __init__(self, signer, issuer=None, subject=None, audience=None,
issuer (str): The `iss` claim.
subject (str): The `sub` claim.
audience (str): the `aud` claim. The intended audience for the
credentials. If not specified, a new JWT will be generated for
every request and will use the request URI as the audience.
credentials.
additional_claims (Mapping[str, str]): Any additional claims for
the JWT payload.
token_lifetime (int): The amount of time in seconds for
Expand Down Expand Up @@ -334,7 +332,8 @@ def _from_signer_and_info(cls, signer, info, **kwargs):
ValueError: If the info is not in the expected format.
"""
kwargs.setdefault('subject', info['client_email'])
return cls(signer, issuer=info['client_email'], **kwargs)
kwargs.setdefault('issuer', info['client_email'])
return cls(signer, **kwargs)

@classmethod
def from_service_account_info(cls, info, **kwargs):
Expand Down Expand Up @@ -381,9 +380,8 @@ def with_claims(self, issuer=None, subject=None, audience=None,
claim will be used.
subject (str): The `sub` claim. If unspecified the current subject
claim will be used.
audience (str): the `aud` claim. If not specified, a new
JWT will be generated for every request and will use
the request URI as the audience.
audience (str): the `aud` claim. If unspecified the current

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

audience claim will be used.
additional_claims (Mapping[str, str]): Any additional claims for
the JWT payload. This will be merged with the current
additional claims.
Expand All @@ -399,12 +397,9 @@ def with_claims(self, issuer=None, subject=None, audience=None,
additional_claims=self._additional_claims.copy().update(
additional_claims or {}))

def _make_jwt(self, audience=None):
def _make_jwt(self):
"""Make a signed JWT.

Args:
audience (str): Overrides the instance's current audience claim.

Returns:
Tuple[bytes, datetime]: The encoded JWT and the expiration.
"""
Expand All @@ -414,10 +409,10 @@ def _make_jwt(self, audience=None):

payload = {
'iss': self._issuer,
'sub': self._subject or self._issuer,
'sub': self._subject,
'iat': _helpers.datetime_to_secs(now),
'exp': _helpers.datetime_to_secs(expiry),
'aud': audience or self._audience,
'aud': self._audience,
}

payload.update(self._additional_claims)
Expand All @@ -426,22 +421,6 @@ def _make_jwt(self, audience=None):

return jwt, expiry

def _make_one_time_jwt(self, uri):
"""Makes a one-off JWT with the URI as the audience.

Args:
uri (str): The request URI.

Returns:
bytes: The encoded JWT.
"""
parts = urllib.parse.urlsplit(uri)
# Strip query string and fragment
audience = urllib.parse.urlunsplit(
(parts.scheme, parts.netloc, parts.path, None, None))
token, _ = self._make_jwt(audience=audience)
return token

def refresh(self, request):
"""Refreshes the access token.

Expand All @@ -452,15 +431,8 @@ def refresh(self, request):
# (pylint doesn't correctly recognize overridden methods.)
self.token, self.expiry = self._make_jwt()

@_helpers.copy_docstring(credentials.Signing)

This comment was marked as spam.

This comment was marked as spam.

def sign_bytes(self, message):
"""Signs the given message.

Args:
message (bytes): The message to sign.

Returns:
bytes: The message signature.
"""
return self._signer.sign(message)

@property
Expand All @@ -472,32 +444,3 @@ def signer_email(self):
@_helpers.copy_docstring(credentials.Signing)
def signer(self):
return self._signer

def before_request(self, request, method, url, headers):
"""Performs credential-specific before request logic.

If an audience is specified it will refresh the credentials if
necessary. If no audience is specified it will generate a one-time
token for the request URI. In either case, it will set the
authorization header in headers to the token.

Args:
request (Any): Unused.
method (str): The request's HTTP method.
url (str): The request's URI.
headers (Mapping): The request's headers.
"""
# pylint: disable=unused-argument
# (pylint doesn't correctly recognize overridden methods.)

# If this set of credentials has a pre-set audience, just ensure that
# there is a valid token and apply the auth headers.
if self._audience:
if not self.valid:
self.refresh(request)
self.apply(headers)
# Otherwise, generate a one-time token using the URL
# (without the query string and fragment) as the audience.
else:
token = self._make_one_time_jwt(url)
self.apply(headers, token=token)
9 changes: 7 additions & 2 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def from_service_account_file(cls, filename, **kwargs):
filename, require=['client_email', 'token_uri'])
return cls._from_signer_and_info(signer, info, **kwargs)

def to_jwt_credentials(self):
def to_jwt_credentials(self, audience):
"""Creates a :class:`google.auth.jwt.Credentials` instance from this
instance.

Expand All @@ -223,13 +223,18 @@ def to_jwt_credentials(self):
jwt_creds = jwt.Credentials.from_service_account_file(
'service_account.json')

Args:
audience (str): the `aud` claim. The intended audience for the

This comment was marked as spam.

credentials.

Returns:
google.auth.jwt.Credentials: A new Credentials instance.
"""
return jwt.Credentials(
self._signer,
issuer=self._service_account_email,
subject=self._service_account_email)
subject=self._service_account_email,
audience=audience)

@property
def service_account_email(self):
Expand Down
5 changes: 4 additions & 1 deletion system_tests/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def test_grpc_request_with_regular_credentials(http_request):

def test_grpc_request_with_jwt_credentials(http_request):
credentials, project_id = google.auth.default()
credentials = credentials.to_jwt_credentials()
audience = 'https://{}/google.pubsub.v1.Publisher'.format(
publisher_client.PublisherClient.SERVICE_ADDRESS)
credentials = credentials.to_jwt_credentials(
audience=audience)

channel = google.auth.transport.grpc.secure_authorized_channel(
credentials,
Expand Down
6 changes: 4 additions & 2 deletions tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ def test_from_service_account_file_args(self):
assert credentials._additional_claims == additional_claims

def test_to_jwt_credentials(self):
jwt_from_svc = self.credentials.to_jwt_credentials()
jwt_from_svc = self.credentials.to_jwt_credentials(
audience=mock.sentinel.audience)
jwt_from_info = jwt.Credentials.from_service_account_info(
SERVICE_ACCOUNT_INFO)
SERVICE_ACCOUNT_INFO,
audience=mock.sentinel.audience)

assert isinstance(jwt_from_svc, jwt.Credentials)
assert jwt_from_svc._signer.key_id == jwt_from_info._signer.key_id
Expand Down
51 changes: 24 additions & 27 deletions tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,20 @@ class TestCredentials:
@pytest.fixture(autouse=True)
def credentials_fixture(self, signer):
self.credentials = jwt.Credentials(
signer, self.SERVICE_ACCOUNT_EMAIL)
signer, self.SERVICE_ACCOUNT_EMAIL, self.SERVICE_ACCOUNT_EMAIL,
self.AUDIENCE)

def test_from_service_account_info(self):
with open(SERVICE_ACCOUNT_JSON_FILE, 'r') as fh:
info = json.load(fh)

credentials = jwt.Credentials.from_service_account_info(info)
credentials = jwt.Credentials.from_service_account_info(
info, audience=self.AUDIENCE)

assert credentials._signer.key_id == info['private_key_id']
assert credentials._issuer == info['client_email']
assert credentials._subject == info['client_email']
assert credentials._audience == self.AUDIENCE

def test_from_service_account_info_args(self):
info = SERVICE_ACCOUNT_INFO.copy()
Expand All @@ -235,11 +238,12 @@ def test_from_service_account_file(self):
info = SERVICE_ACCOUNT_INFO.copy()

credentials = jwt.Credentials.from_service_account_file(
SERVICE_ACCOUNT_JSON_FILE)
SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE)

assert credentials._signer.key_id == info['private_key_id']
assert credentials._issuer == info['client_email']
assert credentials._subject == info['client_email']
assert credentials._audience == self.AUDIENCE

def test_from_service_account_file_args(self):
info = SERVICE_ACCOUNT_INFO.copy()
Expand All @@ -259,6 +263,18 @@ def test_default_state(self):
# Expiration hasn't been set yet
assert not self.credentials.expired

def test_with_claims(self):
new_audience = 'new_audience'
new_credentials = self.credentials.with_claims(
audience=new_audience)

assert new_credentials._signer == self.credentials._signer
assert new_credentials._issuer == self.credentials._issuer
assert new_credentials._subject == self.credentials._subject
assert new_credentials._audience == new_audience
assert (new_credentials._additional_claims ==
self.credentials._additional_claims)

def test_sign_bytes(self):
to_sign = b'123'
signature = self.credentials.sign_bytes(to_sign)
Expand Down Expand Up @@ -292,43 +308,24 @@ def test_expired(self):
now.return_value = self.credentials.expiry + one_day
assert self.credentials.expired

def test_before_request_one_time_token(self):
def test_before_request(self):
headers = {}

self.credentials.refresh(None)
self.credentials.before_request(
mock.Mock(), 'GET', 'http://example.com?a=1#3', headers)

header_value = headers['authorization']
_, token = header_value.split(' ')

# This should be a one-off token, so it shouldn't be the same as the
# credentials' stored token.
assert token != self.credentials.token

payload = self._verify_token(token)
assert payload['aud'] == 'http://example.com'

def test_before_request_with_preset_audience(self):
headers = {}

credentials = self.credentials.with_claims(audience=self.AUDIENCE)
credentials.refresh(None)
credentials.before_request(
None, 'GET', 'http://example.com?a=1#3', headers)

header_value = headers['authorization']
_, token = header_value.split(' ')

# Since the audience is set, it should use the existing token.
assert token.encode('utf-8') == credentials.token
assert token.encode('utf-8') == self.credentials.token

payload = self._verify_token(token)
assert payload['aud'] == self.AUDIENCE

def test_before_request_refreshes(self):
credentials = self.credentials.with_claims(audience=self.AUDIENCE)
assert not credentials.valid
credentials.before_request(
assert not self.credentials.valid
self.credentials.before_request(
None, 'GET', 'http://example.com?a=1#3', {})
assert credentials.valid
assert self.credentials.valid