diff --git a/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py b/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py
index edc87827857af..aa73ffdc8b7db 100644
--- a/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py
+++ b/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py
@@ -153,6 +153,15 @@ def __init__(
                 raise VaultError("The 'radius' authentication type requires 'radius_host'")
             if not radius_secret:
                 raise VaultError("The 'radius' authentication type requires 'radius_secret'")
+        if auth_type == "gcp":
+            if not gcp_scopes:
+                raise VaultError("The 'gcp' authentication type requires 'gcp_scopes'")
+            if not role_id:
+                raise VaultError("The 'gcp' authentication type requires 'role_id'")
+            if not gcp_key_path and not gcp_keyfile_dict:
+                raise VaultError(
+                    "The 'gcp' authentication type requires 'gcp_key_path' or 'gcp_keyfile_dict'"
+                )
 
         self.kv_engine_version = kv_engine_version or 2
         self.url = url
@@ -303,13 +312,43 @@ def _auth_gcp(self, _client: hvac.Client) -> None:
         )
 
         scopes = _get_scopes(self.gcp_scopes)
-        credentials, _ = get_credentials_and_project_id(
+        credentials, project_id = get_credentials_and_project_id(
             key_path=self.gcp_key_path, keyfile_dict=self.gcp_keyfile_dict, scopes=scopes
         )
+
+        import json
+        import time
+
+        # Import the submodule explicitly: a bare "import googleapiclient" does not load "discovery"
+        import googleapiclient.discovery
+
+        if self.gcp_keyfile_dict:
+            creds = self.gcp_keyfile_dict
+        elif self.gcp_key_path:
+            with open(self.gcp_key_path) as f:
+                creds = json.load(f)
+
+        service_account = creds["client_email"]
+
+        # Generate a payload for subsequent "signJwt()" call
+        # Reference: https://googleapis.dev/python/google-auth/latest/reference/google.auth.jwt.html#google.auth.jwt.Credentials
+        now = int(time.time())
+        expires = now + 900  # 15 mins in seconds, can't be longer.
+        # "sub" is the service account e-mail; the credentials object itself is not JSON-serializable
+        payload = {"iat": now, "exp": expires, "sub": service_account, "aud": f"vault/{self.role_id}"}
+        body = {"payload": json.dumps(payload)}
+        name = f"projects/{project_id}/serviceAccounts/{service_account}"
+
+        # Perform the GCP API call
+        iam = googleapiclient.discovery.build("iam", "v1", credentials=credentials)
+        request = iam.projects().serviceAccounts().signJwt(name=name, body=body)
+        resp = request.execute()
+        jwt = resp["signedJwt"]
+
         if self.auth_mount_point:
-            _client.auth.gcp.configure(credentials=credentials, mount_point=self.auth_mount_point)
+            _client.auth.gcp.login(role=self.role_id, jwt=jwt, mount_point=self.auth_mount_point)
         else:
-            _client.auth.gcp.configure(credentials=credentials)
+            _client.auth.gcp.login(role=self.role_id, jwt=jwt)
 
     def _auth_azure(self, _client: hvac.Client) -> None:
         if self.auth_mount_point:
diff --git a/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py b/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py
index 8b98a814394d5..c9239b75a99e4 100644
--- a/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py
+++ b/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py
@@ -16,6 +16,8 @@
 # under the License.
from __future__ import annotations +import json +import time from unittest import mock from unittest.mock import mock_open, patch @@ -253,86 +255,217 @@ def test_azure_missing_tenant_id(self, mock_hvac): secret_id="pass", ) + @mock.patch("builtins.open", create=True) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes): + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") + @mock.patch("googleapiclient.discovery.build") + def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open): + # Mock the content of the file 'path.json' + mock_file = mock.MagicMock() + mock_file.read.return_value = '{"client_email": "service_account_email"}' + mock_open.return_value.__enter__.return_value = mock_file + mock_client = mock.MagicMock() - mock_hvac.Client.return_value = mock_client + mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + + # Mock the current time to use for iat and exp + current_time = int(time.time()) + iat = current_time + exp = iat + 3600 # 1 hour after iat + + # Mock the signJwt API to return the expected payload + mock_sign_jwt = ( + mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + ) + mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} + vault_client = _VaultClient( auth_type="gcp", gcp_key_path="path.json", gcp_scopes="scope1,scope2", + role_id="role", url="http://localhost:8180", session=None, ) - client = vault_client.client - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) + + # Preserve 
the original json.dumps + original_json_dumps = json.dumps + + # Inject the mocked payload into the JWT signing process + with mock.patch("json.dumps") as mock_json_dumps: + + def mocked_json_dumps(payload): + # Override the payload to inject controlled iat and exp values + payload["iat"] = iat + payload["exp"] = exp + return original_json_dumps(payload) # Use the original json.dumps + + mock_json_dumps.side_effect = mocked_json_dumps + + client = vault_client.client # Trigger the Vault client creation + + # Validate that the HVAC client and other mocks are called correctly + mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + + # Extract the arguments passed to the mocked signJwt API + args, kwargs = mock_sign_jwt.call_args + payload = json.loads(kwargs["body"]["payload"]) + + # Assert iat and exp values are as expected + assert payload["iat"] == iat + assert payload["exp"] == exp + assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + + client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 + @mock.patch("builtins.open", create=True) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, mock_get_scopes): - mock_client = mock.MagicMock() - mock_hvac.Client.return_value = 
mock_client + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") + @mock.patch("googleapiclient.discovery.build") + def test_gcp_different_auth_mount_point( + self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open + ): + # Mock the content of the file 'path.json' + mock_file = mock.MagicMock() + mock_file.read.return_value = '{"client_email": "service_account_email"}' + mock_open.return_value.__enter__.return_value = mock_file + + mock_client = mock.MagicMock() + mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + + mock_sign_jwt = ( + mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + ) + mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} + + # Generate realistic iat and exp values + current_time = int(time.time()) + iat = current_time + exp = current_time + 3600 # 1 hour later + vault_client = _VaultClient( auth_type="gcp", gcp_key_path="path.json", gcp_scopes="scope1,scope2", + role_id="role", url="http://localhost:8180", auth_mount_point="other", session=None, ) - client = vault_client.client - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) + + # Preserve the original json.dumps + original_json_dumps = json.dumps + + # Inject the mocked payload into the JWT signing process + with mock.patch("json.dumps") as mock_json_dumps: + + def mocked_json_dumps(payload): + # Override the payload to inject controlled iat and exp values + payload["iat"] = iat + payload["exp"] = exp + return original_json_dumps(payload) # Use the original json.dumps + + mock_json_dumps.side_effect = mocked_json_dumps + + client = vault_client.client # Trigger the Vault client creation + + # Assertions + mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) 
mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with(credentials="credentials", mount_point="other") + # Extract the arguments passed to the mocked signJwt API + args, kwargs = mock_sign_jwt.call_args + payload = json.loads(kwargs["body"]["payload"]) + + # Assert iat and exp values are as expected + assert payload["iat"] == iat + assert payload["exp"] == exp + assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + + client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt", mount_point="other") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 + @mock.patch( + "builtins.open", new_callable=mock_open, read_data='{"client_email": "service_account_email"}' + ) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes): + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") + @mock.patch("googleapiclient.discovery.build") + def test_gcp_dict( + self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_file + ): mock_client = mock.MagicMock() - mock_hvac.Client.return_value = mock_client + mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + + mock_sign_jwt = ( + mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + ) + 
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} + + # Generate realistic iat and exp values + current_time = int(time.time()) + iat = current_time + exp = current_time + 3600 # 1 hour later + vault_client = _VaultClient( auth_type="gcp", - gcp_keyfile_dict={"key": "value"}, + gcp_keyfile_dict={"client_email": "service_account_email"}, gcp_scopes="scope1,scope2", + role_id="role", url="http://localhost:8180", session=None, ) - client = vault_client.client - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) + + # Preserve the original json.dumps + original_json_dumps = json.dumps + + # Inject the mocked payload into the JWT signing process + with mock.patch("json.dumps") as mock_json_dumps: + + def mocked_json_dumps(payload): + # Override the payload to inject controlled iat and exp values + payload["iat"] = iat + payload["exp"] = exp + return original_json_dumps(payload) # Use the original json.dumps + + mock_json_dumps.side_effect = mocked_json_dumps + + client = vault_client.client # Trigger the Vault client creation + + # Assertions + mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( - key_path=None, keyfile_dict={"key": "value"}, scopes=["scope1", "scope2"] - ) - mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - client.auth.gcp.configure.assert_called_with( - credentials="credentials", + key_path=None, keyfile_dict={"client_email": "service_account_email"}, scopes=["scope1", "scope2"] ) + # Extract the arguments passed to the mocked signJwt API + args, kwargs = mock_sign_jwt.call_args + payload = json.loads(kwargs["body"]["payload"]) + + # Assert iat and exp values are as expected + assert payload["iat"] == iat + assert payload["exp"] == exp + assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + + 
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 diff --git a/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py b/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py index fd0862329e7ef..fa573eb6624a1 100644 --- a/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py +++ b/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py @@ -18,7 +18,7 @@ import re from unittest import mock -from unittest.mock import PropertyMock, mock_open, patch +from unittest.mock import MagicMock, PropertyMock, mock_open, patch import pytest from hvac.exceptions import VaultError @@ -431,7 +431,10 @@ def test_azure_dejson(self, mock_hvac, mock_get_connection): @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes): + @mock.patch("googleapiclient.discovery.build") + def test_gcp_init_params( + self, mock_build, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes + ): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_connection = self.get_mock_connection() @@ -439,6 +442,17 @@ def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_credenti mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + # Mock googleapiclient.discovery.build chain + mock_service = MagicMock() + mock_projects = MagicMock() + mock_service_accounts = MagicMock() + mock_sign_jwt = MagicMock() + mock_sign_jwt.execute.return_value = {"signedJwt": "mocked_jwt"} + mock_service_accounts.signJwt.return_value = mock_sign_jwt + 
mock_projects.serviceAccounts.return_value = mock_service_accounts + mock_service.projects.return_value = mock_projects + mock_build.return_value = mock_service + connection_dict = {} mock_connection.extra_dejson.get.side_effect = connection_dict.get @@ -447,20 +461,24 @@ def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_credenti "auth_type": "gcp", "gcp_key_path": "path.json", "gcp_scopes": "scope1,scope2", + "role_id": "role", "session": None, } - test_hook = VaultHook(**kwargs) - test_client = test_hook.get_conn() + with patch( + "builtins.open", mock_open(read_data='{"client_email": "service_account_email"}') + ) as mock_file: + test_hook = VaultHook(**kwargs) + test_client = test_hook.get_conn() + mock_file.assert_called_with("path.json") + mock_get_connection.assert_called_with("vault_conn_id") mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - test_client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + test_client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") test_client.is_authenticated.assert_called_with() assert test_hook.vault_client.kv_engine_version == 2 @@ -468,7 +486,10 @@ def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_credenti @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_dejson(self, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes): + @mock.patch("googleapiclient.discovery.build") + def test_gcp_dejson( + self, mock_build, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes + ): mock_client = 
mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_connection = self.get_mock_connection() @@ -476,29 +497,45 @@ def test_gcp_dejson(self, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + # Mock googleapiclient.discovery.build chain + mock_service = MagicMock() + mock_projects = MagicMock() + mock_service_accounts = MagicMock() + mock_sign_jwt = MagicMock() + mock_sign_jwt.execute.return_value = {"signedJwt": "mocked_jwt"} + mock_service_accounts.signJwt.return_value = mock_sign_jwt + mock_projects.serviceAccounts.return_value = mock_service_accounts + mock_service.projects.return_value = mock_projects + mock_build.return_value = mock_service + connection_dict = { "auth_type": "gcp", "gcp_key_path": "path.json", "gcp_scopes": "scope1,scope2", + "role_id": "role", } mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { "vault_conn_id": "vault_conn_id", "session": None, + "role_id": "role", } - test_hook = VaultHook(**kwargs) - test_client = test_hook.get_conn() + with patch( + "builtins.open", mock_open(read_data='{"client_email": "service_account_email"}') + ) as mock_file: + test_hook = VaultHook(**kwargs) + test_client = test_hook.get_conn() + mock_file.assert_called_with("path.json") + mock_get_connection.assert_called_with("vault_conn_id") mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - test_client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + test_client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") test_client.is_authenticated.assert_called_with() assert test_hook.vault_client.kv_engine_version == 2 @@ -506,7 +543,10 @@ def test_gcp_dejson(self, mock_hvac, 
mock_get_connection, mock_get_credentials, @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes): + @mock.patch("googleapiclient.discovery.build") + def test_gcp_dict_dejson( + self, mock_build, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes + ): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_connection = self.get_mock_connection() @@ -514,16 +554,29 @@ def test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, mock_get_credenti mock_get_scopes.return_value = ["scope1", "scope2"] mock_get_credentials.return_value = ("credentials", "project_id") + # Mock googleapiclient.discovery.build chain + mock_service = MagicMock() + mock_projects = MagicMock() + mock_service_accounts = MagicMock() + mock_sign_jwt = MagicMock() + mock_sign_jwt.execute.return_value = {"signedJwt": "mocked_jwt"} + mock_service_accounts.signJwt.return_value = mock_sign_jwt + mock_projects.serviceAccounts.return_value = mock_service_accounts + mock_service.projects.return_value = mock_projects + mock_build.return_value = mock_service + connection_dict = { "auth_type": "gcp", - "gcp_keyfile_dict": '{"key": "value"}', + "gcp_keyfile_dict": '{"client_email": "service_account_email"}', "gcp_scopes": "scope1,scope2", + "role_id": "role", } mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { "vault_conn_id": "vault_conn_id", "session": None, + "role_id": "role", } test_hook = VaultHook(**kwargs) @@ -531,12 +584,10 @@ def test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, mock_get_credenti mock_get_connection.assert_called_with("vault_conn_id") mock_get_scopes.assert_called_with("scope1,scope2") 
mock_get_credentials.assert_called_with( - key_path=None, keyfile_dict={"key": "value"}, scopes=["scope1", "scope2"] + key_path=None, keyfile_dict={"client_email": "service_account_email"}, scopes=["scope1", "scope2"] ) mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) - test_client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + test_client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") test_client.is_authenticated.assert_called_with() assert test_hook.vault_client.kv_engine_version == 2