Skip to content

Commit

Permalink
fix: fix id_token iam endpoint for non-gdu service credentials (#1506)
Browse files Browse the repository at this point in the history
* fix: fix id_token iam endpoint for non-gdu service credentials

* chore: address comments
  • Loading branch information
arithmetic1728 authored Mar 28, 2024
1 parent 089206e commit 93d681e
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 37 deletions.
21 changes: 18 additions & 3 deletions google/auth/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,23 @@
from google.auth import crypt
from google.auth import exceptions

_IAM_API_ROOT_URI = "https://iamcredentials.googleapis.com/v1"
_SIGN_BLOB_URI = _IAM_API_ROOT_URI + "/projects/-/serviceAccounts/{}:signBlob?alt=json"

_IAM_SCOPE = ["https://www.googleapis.com/auth/iam"]

_IAM_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
+ "/serviceAccounts/{}:generateAccessToken"
)

_IAM_SIGN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
+ "/serviceAccounts/{}:signBlob"
)

_IAM_IDTOKEN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/"
+ "projects/-/serviceAccounts/{}:generateIdToken"
)


class Signer(crypt.Signer):
Expand Down Expand Up @@ -67,7 +82,7 @@ def _make_signing_request(self, message):
message = _helpers.to_bytes(message)

method = "POST"
url = _SIGN_BLOB_URI.format(self._service_account_email)
url = _IAM_SIGN_ENDPOINT.format(self._service_account_email)
headers = {"Content-Type": "application/json"}
body = json.dumps(
{"payload": base64.b64encode(message).decode("utf-8")}
Expand Down
29 changes: 7 additions & 22 deletions google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,15 @@
from google.auth import _helpers
from google.auth import credentials
from google.auth import exceptions
from google.auth import iam
from google.auth import jwt
from google.auth import metrics

_IAM_SCOPE = ["https://www.googleapis.com/auth/iam"]

_IAM_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
+ "/serviceAccounts/{}:generateAccessToken"
)

_IAM_SIGN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
+ "/serviceAccounts/{}:signBlob"
)

_IAM_IDTOKEN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/"
+ "projects/-/serviceAccounts/{}:generateIdToken"
)

_REFRESH_ERROR = "Unable to acquire impersonated credentials"

_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds

_DEFAULT_TOKEN_URI = "https://oauth2.googleapis.com/token"


def _make_iam_token_request(
request, principal, headers, body, iam_endpoint_override=None
Expand All @@ -83,7 +66,7 @@ def _make_iam_token_request(
`iamcredentials.googleapis.com` is not enabled or the
`Service Account Token Creator` is not assigned
"""
iam_endpoint = iam_endpoint_override or _IAM_ENDPOINT.format(principal)
iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(principal)

body = json.dumps(body).encode("utf-8")

Expand Down Expand Up @@ -225,7 +208,9 @@ def __init__(
# added to refresh correctly. User credentials cannot have
# their original scopes modified.
if isinstance(self._source_credentials, credentials.Scoped):
self._source_credentials = self._source_credentials.with_scopes(_IAM_SCOPE)
self._source_credentials = self._source_credentials.with_scopes(
iam._IAM_SCOPE
)
# If the source credential is service account and self signed jwt
# is needed, we need to create a jwt credential inside it
if (
Expand Down Expand Up @@ -290,7 +275,7 @@ def _update_token(self, request):
def sign_bytes(self, message):
from google.auth.transport.requests import AuthorizedSession

iam_sign_endpoint = _IAM_SIGN_ENDPOINT.format(self._target_principal)
iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.format(self._target_principal)

body = {
"payload": base64.b64encode(message).decode("utf-8"),
Expand Down Expand Up @@ -425,7 +410,7 @@ def with_quota_project(self, quota_project_id):
def refresh(self, request):
from google.auth.transport.requests import AuthorizedSession

iam_sign_endpoint = _IAM_IDTOKEN_ENDPOINT.format(
iam_sign_endpoint = iam._IAM_IDTOKEN_ENDPOINT.format(
self._target_credentials.signer_email
)

Expand Down
11 changes: 5 additions & 6 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@
_JSON_CONTENT_TYPE = "application/json"
_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
_REFRESH_GRANT_TYPE = "refresh_token"
_IAM_IDTOKEN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/"
+ "projects/-/serviceAccounts/{}:generateIdToken"
)


def _handle_error_response(response_data, retryable_error):
Expand Down Expand Up @@ -328,12 +324,15 @@ def jwt_grant(request, token_uri, assertion, can_retry=True):
return access_token, expiry, response_data


def call_iam_generate_id_token_endpoint(request, signer_email, audience, access_token):
def call_iam_generate_id_token_endpoint(
request, iam_id_token_endpoint, signer_email, audience, access_token
):
"""Call iam.generateIdToken endpoint to get ID token.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
iam_id_token_endpoint (str): The IAM ID token endpoint to use.
signer_email (str): The signer email used to form the IAM
generateIdToken endpoint.
audience (str): The audience for the ID token.
Expand All @@ -346,7 +345,7 @@ def call_iam_generate_id_token_endpoint(request, signer_email, audience, access_

response_data = _token_endpoint_request(
request,
_IAM_IDTOKEN_ENDPOINT.format(signer_email),
iam_id_token_endpoint.format(signer_email),
body,
access_token=access_token,
use_json=True,
Expand Down
7 changes: 6 additions & 1 deletion google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from google.auth import _service_account_info
from google.auth import credentials
from google.auth import exceptions
from google.auth import iam
from google.auth import jwt
from google.auth import metrics
from google.oauth2 import _client
Expand Down Expand Up @@ -595,8 +596,11 @@ def __init__(
self._universe_domain = credentials.DEFAULT_UNIVERSE_DOMAIN
else:
self._universe_domain = universe_domain
self._iam_id_token_endpoint = iam._IAM_IDTOKEN_ENDPOINT.replace(
"googleapis.com", self._universe_domain
)

if universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN:
if self._universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN:
self._use_iam_endpoint = True

if additional_claims is not None:
Expand Down Expand Up @@ -792,6 +796,7 @@ def _refresh_with_iam_endpoint(self, request):
jwt_credentials.refresh(request)
self.token, self.expiry = _client.call_iam_generate_id_token_endpoint(
request,
self._iam_id_token_endpoint,
self.signer_email,
self._target_audience,
jwt_credentials.token.decode(),
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def test_with_target_audience_integration(self):
responses.add(
responses.POST,
"https://iamcredentials.googleapis.com/v1/projects/-/"
"serviceAccounts/service-account@example.com:signBlob?alt=json",
"serviceAccounts/service-account@example.com:signBlob",
status=200,
content_type="application/json",
json={"keyId": "some-key-id", "signedBlob": signature},
Expand Down Expand Up @@ -657,7 +657,7 @@ def test_with_quota_project_integration(self):
responses.add(
responses.POST,
"https://iamcredentials.googleapis.com/v1/projects/-/"
"serviceAccounts/service-account@example.com:signBlob?alt=json",
"serviceAccounts/service-account@example.com:signBlob",
status=200,
content_type="application/json",
json={"keyId": "some-key-id", "signedBlob": signature},
Expand Down
13 changes: 11 additions & 2 deletions tests/oauth2/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.auth import _helpers
from google.auth import crypt
from google.auth import exceptions
from google.auth import iam
from google.auth import jwt
from google.auth import transport
from google.oauth2 import _client
Expand Down Expand Up @@ -318,7 +319,11 @@ def test_call_iam_generate_id_token_endpoint():
request = make_request({"token": id_token})

token, expiry = _client.call_iam_generate_id_token_endpoint(
request, "fake_email", "fake_audience", "fake_access_token"
request,
iam._IAM_IDTOKEN_ENDPOINT,
"fake_email",
"fake_audience",
"fake_access_token",
)

assert (
Expand Down Expand Up @@ -351,7 +356,11 @@ def test_call_iam_generate_id_token_endpoint_no_id_token():

with pytest.raises(exceptions.RefreshError) as excinfo:
_client.call_iam_generate_id_token_endpoint(
request, "fake_email", "fake_audience", "fake_access_token"
request,
iam._IAM_IDTOKEN_ENDPOINT,
"fake_email",
"fake_audience",
"fake_access_token",
)
assert excinfo.match("No ID token in response")

Expand Down
29 changes: 28 additions & 1 deletion tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.auth import _helpers
from google.auth import crypt
from google.auth import exceptions
from google.auth import iam
from google.auth import jwt
from google.auth import transport
from google.auth.credentials import DEFAULT_UNIVERSE_DOMAIN
Expand Down Expand Up @@ -771,10 +772,36 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint):
)
request = mock.Mock()
credentials.refresh(request)
req, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
0
]
assert req == request
assert iam_endpoint == iam._IAM_IDTOKEN_ENDPOINT
assert signer_email == "service-account@example.com"
assert target_audience == "https://example.com"
decoded_access_token = jwt.decode(access_token, verify=False)
assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam"

@mock.patch(
"google.oauth2._client.call_iam_generate_id_token_endpoint", autospec=True
)
def test_refresh_iam_flow_non_gdu(self, call_iam_generate_id_token_endpoint):
credentials = self.make_credentials(universe_domain="fake-universe")
token = "id_token"
call_iam_generate_id_token_endpoint.return_value = (
token,
_helpers.utcnow() + datetime.timedelta(seconds=500),
)
request = mock.Mock()
credentials.refresh(request)
req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
0
]
assert req == request
assert (
iam_endpoint
== "https://iamcredentials.fake-universe/v1/projects/-/serviceAccounts/{}:generateIdToken"
)
assert signer_email == "service-account@example.com"
assert target_audience == "https://example.com"
decoded_access_token = jwt.decode(access_token, verify=False)
Expand Down

0 comments on commit 93d681e

Please sign in to comment.