diff --git a/airflow/providers/hashicorp/_internal_client/vault_client.py b/airflow/providers/hashicorp/_internal_client/vault_client.py index 98bcc71405d29..9337a71dd7abb 100644 --- a/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -151,6 +151,13 @@ def __init__( raise VaultError("The 'radius' authentication type requires 'radius_host'") if not radius_secret: raise VaultError("The 'radius' authentication type requires 'radius_secret'") + if auth_type == "gcp": + if not gcp_scopes: + raise VaultError("The 'gcp' authentication type requires 'gcp_scopes'") + if not role_id: + raise VaultError("The 'gcp' authentication type requires 'role_id'") + if not gcp_key_path and not gcp_keyfile_dict: + raise VaultError("The 'gcp' authentication type requires 'gcp_key_path' or 'gcp_keyfile_dict'") self.kv_engine_version = kv_engine_version or 2 self.url = url @@ -299,13 +306,47 @@ def _auth_gcp(self, _client: hvac.Client) -> None: ) scopes = _get_scopes(self.gcp_scopes) - credentials, _ = get_credentials_and_project_id( + credentials, project_id = get_credentials_and_project_id( key_path=self.gcp_key_path, keyfile_dict=self.gcp_keyfile_dict, scopes=scopes ) + + import time + import json + import googleapiclient + + with open(self.gcp_key_path, 'r') as f: + creds = json.load(f) + service_account = creds['client_email'] + + # Generate a payload for subsequent "signJwt()" call + # Reference: https://googleapis.dev/python/google-auth/latest/reference/google.auth.jwt.html#google.auth.jwt.Credentials + now = int(time.time()) + expires = now + 900 # 15 mins in seconds, can't be longer. + payload = { + 'iat': now, + 'exp': expires, + 'sub': credentials, + 'aud': f'vault/{self.role_id}' + } + body = {'payload': json.dumps(payload)} + name = f'projects/{project_id}/serviceAccounts/{service_account}' + + # Perform the GCP API call + iam = googleapiclient.discovery.build('iam', 'v1', credentials=credentials) + request = iam.projects().serviceAccounts().signJwt(name=name, body=body) + resp = request.execute() + jwt = resp['signedJwt'] + if self.auth_mount_point: - _client.auth.gcp.configure(credentials=credentials, mount_point=self.auth_mount_point) + _client.auth.gcp.login( + role=self.role_id, + jwt=jwt, + mount_point=self.auth_mount_point) else: - _client.auth.gcp.configure(credentials=credentials) + _client.auth.gcp.login( + role=self.role_id, + jwt=jwt) + def _auth_azure(self, _client: hvac.Client) -> None: if self.auth_mount_point: diff --git a/tests/providers/hashicorp/_internal_client/test_vault_client.py b/tests/providers/hashicorp/_internal_client/test_vault_client.py index f491f12129007..ec7cc3c2fa641 100644 --- a/tests/providers/hashicorp/_internal_client/test_vault_client.py +++ b/tests/providers/hashicorp/_internal_client/test_vault_client.py @@ -232,15 +232,20 @@ def test_azure_missing_tenant_id(self, mock_hvac): @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes): + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.googleapiclient.discovery.build") + def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes, mock_google_build): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + mock_sign_jwt = mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} + vault_client = _VaultClient( auth_type="gcp", gcp_key_path="path.json", gcp_scopes="scope1,scope2", + role_id="role", url="http://localhost:8180", session=None, ) @@ -251,8 +256,19 @@ def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes): key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with( - credentials="credentials", + mock_sign_jwt.assert_called_with( + name="projects/project_id/serviceAccounts/service_account", + body={"payload": json.dumps({ + "iat": mock.ANY, + "exp": mock.ANY, + "aud": "vault/role", + "sub": "credentials" + })} + ) + + client.auth.gcp.login.assert_called_with( + role="role", + jwt="mocked_jwt", ) client.is_authenticated.assert_called_with() assert 2 == vault_client.kv_engine_version @@ -260,11 +276,15 @@ def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes): @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, mock_get_scopes): + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.googleapiclient.discovery.build") + def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, mock_get_scopesm, mock_google_build): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + mock_sign_jwt = mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} + vault_client = _VaultClient( auth_type="gcp", gcp_key_path="path.json", @@ -280,7 +300,20 @@ def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, m key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with(credentials="credentials", mount_point="other") + mock_sign_jwt.assert_called_with( + name="projects/project_id/serviceAccounts/service_account", + body={"payload": json.dumps({ + "iat": mock.ANY, + "exp": mock.ANY, + "aud": "vault/test_role", + "sub": "credentials" + })} + ) + client.auth.gcp.login.assert_called_with( + role="role", + jwt="mocked_jwt", + mount_point="other" + ) client.is_authenticated.assert_called_with() assert 2 == vault_client.kv_engine_version @@ -306,8 +339,9 @@ def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes): key_path=None, keyfile_dict={"key": "value"}, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with( - credentials="credentials", + client.auth.gcp.login.assert_called_with( + role="TODO", + jwt="TODO", ) client.is_authenticated.assert_called_with() assert 2 == vault_client.kv_engine_version