diff --git a/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py b/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py index 98bcc71405d29..73fddd0a1ce9a 100644 --- a/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -151,6 +151,15 @@ def __init__( raise VaultError("The 'radius' authentication type requires 'radius_host'") if not radius_secret: raise VaultError("The 'radius' authentication type requires 'radius_secret'") + if auth_type == "gcp": + if not gcp_scopes: + raise VaultError("The 'gcp' authentication type requires 'gcp_scopes'") + if not role_id: + raise VaultError("The 'gcp' authentication type requires 'role_id'") + if not gcp_key_path and not gcp_keyfile_dict: + raise VaultError( + "The 'gcp' authentication type requires 'gcp_key_path' or 'gcp_keyfile_dict'" + ) self.kv_engine_version = kv_engine_version or 2 self.url = url @@ -299,13 +308,36 @@ def _auth_gcp(self, _client: hvac.Client) -> None: ) scopes = _get_scopes(self.gcp_scopes) - credentials, _ = get_credentials_and_project_id( + credentials, project_id = get_credentials_and_project_id( key_path=self.gcp_key_path, keyfile_dict=self.gcp_keyfile_dict, scopes=scopes ) + + import time + import json + import googleapiclient + + with open(self.gcp_key_path, "r") as f: + creds = json.load(f) + service_account = creds["client_email"] + + # Generate a payload for subsequent "signJwt()" call + # Reference: https://googleapis.dev/python/google-auth/latest/reference/google.auth.jwt.html#google.auth.jwt.Credentials + now = int(time.time()) + expires = now + 900 # 15 mins in seconds, can't be longer. + payload = {"iat": now, "exp": expires, "sub": credentials, "aud": f"vault/{self.role_id}"} + body = {"payload": json.dumps(payload)} + name = f"projects/{project_id}/serviceAccounts/{service_account}" + + # Perform the GCP API call + iam = googleapiclient.discovery.build("iam", "v1", credentials=credentials) + request = iam.projects().serviceAccounts().signJwt(name=name, body=body) + resp = request.execute() + jwt = resp["signedJwt"] + if self.auth_mount_point: - _client.auth.gcp.configure(credentials=credentials, mount_point=self.auth_mount_point) + _client.auth.gcp.login(role=self.role_id, jwt=jwt, mount_point=self.auth_mount_point) else: - _client.auth.gcp.configure(credentials=credentials) + _client.auth.gcp.login(role=self.role_id, jwt=jwt) def _auth_azure(self, _client: hvac.Client) -> None: if self.auth_mount_point: diff --git a/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py b/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py index 434b441e8964e..86c25bb0a694d 100644 --- a/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py +++ b/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py @@ -20,12 +20,15 @@ from unittest.mock import mock_open, patch import pytest +import json +import time from hvac.exceptions import InvalidPath, VaultError from requests import Session from requests.adapters import HTTPAdapter from urllib3.util import Retry from airflow.providers.hashicorp._internal_client.vault_client import _VaultClient +import googleapiclient.discovery class TestVaultClient: @@ -229,86 +232,221 @@ def test_azure_missing_tenant_id(self, mock_hvac): secret_id="pass", ) + @mock.patch("builtins.open", create=True) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes): + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") + @mock.patch("googleapiclient.discovery.build") + def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open): + # Mock the content of the file 'path.json' + mock_file = mock.MagicMock() + mock_file.read.return_value = '{"client_email": "service_account_email"}' + mock_open.return_value.__enter__.return_value = mock_file + mock_client = mock.MagicMock() - mock_hvac.Client.return_value = mock_client + mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + + # Mock the current time to use for iat and exp + current_time = int(time.time()) + iat = current_time + exp = iat + 3600 # 1 hour after iat + + # Mock the signJwt API to return the expected payload + mock_sign_jwt = ( + mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + ) + mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} + vault_client = _VaultClient( auth_type="gcp", gcp_key_path="path.json", gcp_scopes="scope1,scope2", + role_id="role", url="http://localhost:8180", session=None, ) - client = vault_client.client - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) + + # Preserve the original json.dumps + original_json_dumps = json.dumps + + # Inject the mocked payload into the JWT signing process + with mock.patch("json.dumps") as mock_json_dumps: + + def mocked_json_dumps(payload): + # Override the payload to inject controlled iat and exp values + payload["iat"] = iat + payload["exp"] = exp + return original_json_dumps(payload) # Use the original json.dumps + + mock_json_dumps.side_effect = mocked_json_dumps + + client = vault_client.client # Trigger the Vault client creation + + # Validate that the HVAC client and other mocks are called correctly + mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + + # Extract the arguments passed to the mocked signJwt API + args, kwargs = mock_sign_jwt.call_args + payload = json.loads(kwargs["body"]["payload"]) + + # Assert iat and exp values are as expected + assert payload["iat"] == iat + assert payload["exp"] == exp + assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + + client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 + @mock.patch("builtins.open", create=True) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, mock_get_scopes): + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") + @mock.patch("googleapiclient.discovery.build") + def test_gcp_different_auth_mount_point( + self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open + ): + # Mock the content of the file 'path.json' + mock_file = mock.MagicMock() + mock_file.read.return_value = '{"client_email": "service_account_email"}' + mock_open.return_value.__enter__.return_value = mock_file + mock_client = mock.MagicMock() - mock_hvac.Client.return_value = mock_client + mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + + mock_sign_jwt = ( + mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + ) + mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} + + # Generate realistic iat and exp values + current_time = int(time.time()) + iat = current_time + exp = current_time + 3600 # 1 hour later + vault_client = _VaultClient( auth_type="gcp", gcp_key_path="path.json", gcp_scopes="scope1,scope2", + role_id="role", url="http://localhost:8180", auth_mount_point="other", session=None, ) - client = vault_client.client - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) + + # Preserve the original json.dumps + original_json_dumps = json.dumps + + # Inject the mocked payload into the JWT signing process + with mock.patch("json.dumps") as mock_json_dumps: + + def mocked_json_dumps(payload): + # Override the payload to inject controlled iat and exp values + payload["iat"] = iat + payload["exp"] = exp + return original_json_dumps(payload) # Use the original json.dumps + + mock_json_dumps.side_effect = mocked_json_dumps + + client = vault_client.client # Trigger the Vault client creation + + # Assertions + mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with(credentials="credentials", mount_point="other") + # Extract the arguments passed to the mocked signJwt API + args, kwargs = mock_sign_jwt.call_args + payload = json.loads(kwargs["body"]["payload"]) + + # Assert iat and exp values are as expected + assert payload["iat"] == iat + assert payload["exp"] == exp + assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + + client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt", mount_point="other") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 + @mock.patch("builtins.open", create=True) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes): + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") + @mock.patch("googleapiclient.discovery.build") + def test_gcp_dict( + self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open + ): + # Mock the content of the file 'path.json' + mock_file = mock.MagicMock() + mock_file.read.return_value = '{"client_email": "service_account_email"}' + mock_open.return_value.__enter__.return_value = mock_file + + # Mock the content of the keyfile dict mock_client = mock.MagicMock() - mock_hvac.Client.return_value = mock_client + mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + + mock_sign_jwt = ( + mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + ) + mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} + + # Generate realistic iat and exp values + current_time = int(time.time()) + iat = current_time + exp = current_time + 3600 # 1 hour later + vault_client = _VaultClient( auth_type="gcp", gcp_keyfile_dict={"key": "value"}, gcp_scopes="scope1,scope2", + role_id="role", url="http://localhost:8180", session=None, ) - client = vault_client.client - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) + + # Preserve the original json.dumps + original_json_dumps = json.dumps + + # Inject the mocked payload into the JWT signing process + with mock.patch("json.dumps") as mock_json_dumps: + + def mocked_json_dumps(payload): + # Override the payload to inject controlled iat and exp values + payload["iat"] = iat + payload["exp"] = exp + return original_json_dumps(payload) # Use the original json.dumps + + mock_json_dumps.side_effect = mocked_json_dumps + + client = vault_client.client # Trigger the Vault client creation + + # Assertions + mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( key_path=None, keyfile_dict={"key": "value"}, scopes=["scope1", "scope2"] ) - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + # Extract the arguments passed to the mocked signJwt API + args, kwargs = mock_sign_jwt.call_args + payload = json.loads(kwargs["body"]["payload"]) + + # Assert iat and exp values are as expected + assert payload["iat"] == iat + assert payload["exp"] == exp + assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + + client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 diff --git a/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py b/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py index f835382b1200a..965066b172c7b 100644 --- a/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py +++ b/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py @@ -448,6 +448,7 @@ def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_credenti "auth_type": "gcp", "gcp_key_path": "path.json", "gcp_scopes": "scope1,scope2", + "role_id": "role", "session": None, } @@ -481,6 +482,7 @@ def test_gcp_dejson(self, mock_hvac, mock_get_connection, mock_get_credentials, "auth_type": "gcp", "gcp_key_path": "path.json", "gcp_scopes": "scope1,scope2", + "role_id": "role", } mock_connection.extra_dejson.get.side_effect = connection_dict.get @@ -519,12 +521,14 @@ def test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, mock_get_credenti "auth_type": "gcp", "gcp_keyfile_dict": '{"key": "value"}', "gcp_scopes": "scope1,scope2", + "role_id": "role", } mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { "vault_conn_id": "vault_conn_id", "session": None, + "role_id": "role", } test_hook = VaultHook(**kwargs)