diff --git a/airflow/providers/hashicorp/_internal_client/vault_client.py b/airflow/providers/hashicorp/_internal_client/vault_client.py index f8e5c254d490c..1b31af204ced7 100644 --- a/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -146,6 +146,13 @@ def __init__( raise VaultError("The 'radius' authentication type requires 'radius_host'") if not radius_secret: raise VaultError("The 'radius' authentication type requires 'radius_secret'") + if auth_type == "gcp": + if not gcp_scopes: + raise VaultError("The 'gcp' authentication type requires 'gcp_scopes'") + if not role_id: + raise VaultError("The 'gcp' authentication type requires 'role_id'") + if not gcp_key_path and not gcp_keyfile_dict: + raise VaultError("The 'gcp' authentication type requires 'gcp_key_path' or 'gcp_keyfile_dict'") self.kv_engine_version = kv_engine_version or 2 self.url = url @@ -291,13 +298,47 @@ def _auth_gcp(self, _client: hvac.Client) -> None: ) scopes = _get_scopes(self.gcp_scopes) - credentials, _ = get_credentials_and_project_id( + credentials, project_id = get_credentials_and_project_id( key_path=self.gcp_key_path, keyfile_dict=self.gcp_keyfile_dict, scopes=scopes ) + + import time + import json + import googleapiclient + + with open(self.gcp_key_path, 'r') as f: + creds = json.load(f) + service_account = creds['client_email'] + + # Generate a payload for subsequent "signJwt()" call + # Reference: https://google-auth.readthedocs.io/en/latest/reference/google.auth.jwt.html#google.auth.jwt.Credentials + now = int(time.time()) + expires = now + 900 # 15 mins in seconds, can't be longer. + payload = { + 'iat': now, + 'exp': expires, + 'sub': credentials, + 'aud': f'vault/{self.role_id}' + } + body = {'payload': json.dumps(payload)} + name = f'projects/{project_id}/serviceAccounts/{service_account}' + + # Perform the GCP API call + iam = googleapiclient.discovery.build('iam', 'v1', credentials=credentials) + request = iam.projects().serviceAccounts().signJwt(name=name, body=body) + resp = request.execute() + jwt = resp['signedJwt'] + if self.auth_mount_point: - _client.auth.gcp.configure(credentials=credentials, mount_point=self.auth_mount_point) + _client.auth.gcp.login( + role=self.role_id, + jwt=jwt, + mount_point=self.auth_mount_point) else: - _client.auth.gcp.configure(credentials=credentials) + _client.auth.gcp.login( + role=self.role_id, + jwt=jwt) + def _auth_azure(self, _client: hvac.Client) -> None: if self.auth_mount_point: diff --git a/tests/providers/hashicorp/_internal_client/test_vault_client.py b/tests/providers/hashicorp/_internal_client/test_vault_client.py index bb9a53ceb5327..e79825382cdee 100644 --- a/tests/providers/hashicorp/_internal_client/test_vault_client.py +++ b/tests/providers/hashicorp/_internal_client/test_vault_client.py @@ -251,8 +251,9 @@ def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes): key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with( - credentials="credentials", + client.auth.gcp.login.assert_called_with( + role="TODO", + jwt="TODO", ) client.is_authenticated.assert_called_with() assert 2 == vault_client.kv_engine_version @@ -280,7 +281,11 @@ def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, m key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with(credentials="credentials", mount_point="other") + client.auth.gcp.login.assert_called_with( + role="TODO", + jwt="TODO", + mount_point="other" + ) client.is_authenticated.assert_called_with() assert 2 == vault_client.kv_engine_version @@ -306,8 +311,9 @@ def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes): key_path=None, keyfile_dict={"key": "value"}, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with( - credentials="credentials", + client.auth.gcp.login.assert_called_with( + role="TODO", + jwt="TODO", ) client.is_authenticated.assert_called_with() assert 2 == vault_client.kv_engine_version