Skip to content

Commit 95dfce8

Browse files
feat: add experimental GDCH support
1 parent de1fd41 commit 95dfce8

File tree

7 files changed

+574
-19
lines changed

7 files changed

+574
-19
lines changed

google/auth/_default.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
_SERVICE_ACCOUNT_TYPE = "service_account"
3737
_EXTERNAL_ACCOUNT_TYPE = "external_account"
3838
_IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account"
39+
_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account"
3940
_VALID_TYPES = (
4041
_AUTHORIZED_USER_TYPE,
4142
_SERVICE_ACCOUNT_TYPE,
4243
_EXTERNAL_ACCOUNT_TYPE,
4344
_IMPERSONATED_SERVICE_ACCOUNT_TYPE,
45+
_GDCH_SERVICE_ACCOUNT_TYPE,
4446
)
4547

4648
# Help message when no credentials can be found.
@@ -158,6 +160,8 @@ def _load_credentials_from_info(
158160
credentials, project_id = _get_impersonated_service_account_credentials(
159161
filename, info, scopes
160162
)
163+
elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE:
164+
credentials, project_id = _get_gdch_service_account_credentials(info)
161165
else:
162166
raise exceptions.DefaultCredentialsError(
163167
"The file {file} does not have a valid type. "
@@ -421,6 +425,36 @@ def _get_impersonated_service_account_credentials(filename, info, scopes):
421425
return credentials, None
422426

423427

428+
def _get_gdch_service_account_credentials(info):
429+
from google.oauth2 import gdch_credentials
430+
431+
k8s_ca_cert_path = info.get("k8s_ca_cert_path")
432+
k8s_cert_path = info.get("k8s_cert_path")
433+
k8s_key_path = info.get("k8s_key_path")
434+
k8s_token_endpoint = info.get("k8s_token_endpoint")
435+
ais_ca_cert_path = info.get("ais_ca_cert_path")
436+
ais_token_endpoint = info.get("ais_token_endpoint")
437+
438+
format_version = info.get("format_version")
439+
if format_version != "v1":
440+
raise exceptions.DefaultCredentialsError(
441+
"format_version is not provided or unsupported. Supported version is: v1"
442+
)
443+
444+
return (
445+
gdch_credentials.ServiceAccountCredentials(
446+
k8s_ca_cert_path,
447+
k8s_cert_path,
448+
k8s_key_path,
449+
k8s_token_endpoint,
450+
ais_ca_cert_path,
451+
ais_token_endpoint,
452+
None,
453+
),
454+
None,
455+
)
456+
457+
424458
def _apply_quota_project_id(credentials, quota_project_id):
425459
if quota_project_id:
426460
credentials = credentials.with_quota_project(quota_project_id)
@@ -456,6 +490,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
456490
endpoint.
457491
The project ID returned in this case is the one corresponding to the
458492
underlying workload identity pool resource if determinable.
493+
494+
If the environment variable is set to the path of a valid GDCH service
495+
account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH
496+
credential will be returned. The project ID returned is None unless it
497+
is set via `GOOGLE_CLOUD_PROJECT` environment variable.
459498
2. If the `Google Cloud SDK`_ is installed and has application default
460499
credentials set they are loaded and returned.
461500
@@ -490,6 +529,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
490529
.. _Metadata Service: https://cloud.google.com/compute/docs\
491530
/storing-retrieving-metadata
492531
.. _Cloud Run: https://cloud.google.com/run
532+
.. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\
533+
/hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted
493534
494535
Example::
495536

google/oauth2/_client.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ def _handle_error_response(response_data):
4444
"""Translates an error response into an exception.
4545
4646
Args:
47-
response_data (Mapping): The decoded response data.
47+
response_data (Mapping | str): The decoded response data.
4848
4949
Raises:
5050
google.auth.exceptions.RefreshError: The errors contained in response_data.
5151
"""
52+
if isinstance(response_data, six.string_types):
53+
raise exceptions.RefreshError(response_data)
5254
try:
5355
error_details = "{}: {}".format(
5456
response_data["error"], response_data.get("error_description")
@@ -79,7 +81,13 @@ def _parse_expiry(response_data):
7981

8082

8183
def _token_endpoint_request_no_throw(
82-
request, token_uri, body, access_token=None, use_json=False
84+
request,
85+
token_uri,
86+
body,
87+
access_token=None,
88+
use_json=False,
89+
expected_status_code=http_client.OK,
90+
**kwargs
8391
):
8492
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
8593
This function doesn't throw on response errors.
@@ -93,6 +101,10 @@ def _token_endpoint_request_no_throw(
93101
access_token (Optional(str)): The access token needed to make the request.
94102
use_json (Optional(bool)): Use urlencoded format or json format for the
95103
content type. The default value is False.
104+
expected_status_code (Optional(int)): The expected the status code of
105+
the token response. The default value is 200. We may expect other
106+
status code like 201 for GDCH credentials.
107+
kwargs: Additional arguments passed on to the request method.
96108
97109
Returns:
98110
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
@@ -112,32 +124,46 @@ def _token_endpoint_request_no_throw(
112124
# retry to fetch token for maximum of two times if any internal failure
113125
# occurs.
114126
while True:
115-
response = request(method="POST", url=token_uri, headers=headers, body=body)
127+
response = request(
128+
method="POST", url=token_uri, headers=headers, body=body, **kwargs
129+
)
116130
response_body = (
117131
response.data.decode("utf-8")
118132
if hasattr(response.data, "decode")
119133
else response.data
120134
)
121-
response_data = json.loads(response_body)
122135

123-
if response.status == http_client.OK:
136+
if response.status == expected_status_code:
137+
# response_body should be a JSON
138+
response_data = json.loads(response_body)
124139
break
125140
else:
126-
error_desc = response_data.get("error_description") or ""
127-
error_code = response_data.get("error") or ""
128-
if (
129-
any(e == "internal_failure" for e in (error_code, error_desc))
130-
and retry < 1
131-
):
132-
retry += 1
133-
continue
134-
return response.status == http_client.OK, response_data
135-
136-
return response.status == http_client.OK, response_data
141+
# For a failed response, response_body could be a string
142+
try:
143+
response_data = json.loads(response_body)
144+
error_desc = response_data.get("error_description") or ""
145+
error_code = response_data.get("error") or ""
146+
if (
147+
any(e == "internal_failure" for e in (error_code, error_desc))
148+
and retry < 1
149+
):
150+
retry += 1
151+
continue
152+
except ValueError:
153+
response_data = response_body
154+
return response.status == expected_status_code, response_data
155+
156+
return response.status == expected_status_code, response_data
137157

138158

139159
def _token_endpoint_request(
140-
request, token_uri, body, access_token=None, use_json=False
160+
request,
161+
token_uri,
162+
body,
163+
access_token=None,
164+
use_json=False,
165+
expected_status_code=http_client.OK,
166+
**kwargs
141167
):
142168
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
143169
@@ -150,6 +176,10 @@ def _token_endpoint_request(
150176
access_token (Optional(str)): The access token needed to make the request.
151177
use_json (Optional(bool)): Use urlencoded format or json format for the
152178
content type. The default value is False.
179+
expected_status_code (Optional(int)): The expected the status code of
180+
the token response. The default value is 200. We may expect other
181+
status code like 201 for GDCH credentials.
182+
kwargs: Additional arguments passed on to the request method.
153183
154184
Returns:
155185
Mapping[str, str]: The JSON-decoded response data.
@@ -159,7 +189,13 @@ def _token_endpoint_request(
159189
an error.
160190
"""
161191
response_status_ok, response_data = _token_endpoint_request_no_throw(
162-
request, token_uri, body, access_token=access_token, use_json=use_json
192+
request,
193+
token_uri,
194+
body,
195+
access_token=access_token,
196+
use_json=use_json,
197+
expected_status_code=expected_status_code,
198+
**kwargs
163199
)
164200
if not response_status_ok:
165201
_handle_error_response(response_data)

0 commit comments

Comments
 (0)