Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ def __init__(
raise VaultError("The 'radius' authentication type requires 'radius_host'")
if not radius_secret:
raise VaultError("The 'radius' authentication type requires 'radius_secret'")
if auth_type == "gcp":
if not gcp_scopes:
raise VaultError("The 'gcp' authentication type requires 'gcp_scopes'")
if not role_id:
raise VaultError("The 'gcp' authentication type requires 'role_id'")
if not gcp_key_path and not gcp_keyfile_dict:
raise VaultError(
"The 'gcp' authentication type requires 'gcp_key_path' or 'gcp_keyfile_dict'"
)

self.kv_engine_version = kv_engine_version or 2
self.url = url
Expand Down Expand Up @@ -299,13 +308,36 @@ def _auth_gcp(self, _client: hvac.Client) -> None:
)

scopes = _get_scopes(self.gcp_scopes)
credentials, _ = get_credentials_and_project_id(
credentials, project_id = get_credentials_and_project_id(
key_path=self.gcp_key_path, keyfile_dict=self.gcp_keyfile_dict, scopes=scopes
)

import time
import json
import googleapiclient

with open(self.gcp_key_path, "r") as f:
creds = json.load(f)
service_account = creds["client_email"]

# Generate a payload for subsequent "signJwt()" call
# Reference: https://googleapis.dev/python/google-auth/latest/reference/google.auth.jwt.html#google.auth.jwt.Credentials
now = int(time.time())
expires = now + 900 # 15 mins in seconds, can't be longer.
payload = {"iat": now, "exp": expires, "sub": credentials, "aud": f"vault/{self.role_id}"}
body = {"payload": json.dumps(payload)}
name = f"projects/{project_id}/serviceAccounts/{service_account}"

# Perform the GCP API call
iam = googleapiclient.discovery.build("iam", "v1", credentials=credentials)
request = iam.projects().serviceAccounts().signJwt(name=name, body=body)
resp = request.execute()
jwt = resp["signedJwt"]

if self.auth_mount_point:
_client.auth.gcp.configure(credentials=credentials, mount_point=self.auth_mount_point)
_client.auth.gcp.login(role=self.role_id, jwt=jwt, mount_point=self.auth_mount_point)
else:
_client.auth.gcp.configure(credentials=credentials)
_client.auth.gcp.login(role=self.role_id, jwt=jwt)

def _auth_azure(self, _client: hvac.Client) -> None:
if self.auth_mount_point:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
from unittest.mock import mock_open, patch

import pytest
import json
import time
from hvac.exceptions import InvalidPath, VaultError
from requests import Session
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from airflow.providers.hashicorp._internal_client.vault_client import _VaultClient
import googleapiclient.discovery


class TestVaultClient:
Expand Down Expand Up @@ -229,86 +232,221 @@ def test_azure_missing_tenant_id(self, mock_hvac):
secret_id="pass",
)

@mock.patch("builtins.open", create=True)
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes):
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
@mock.patch("googleapiclient.discovery.build")
def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open):
# Mock the content of the file 'path.json'
mock_file = mock.MagicMock()
mock_file.read.return_value = '{"client_email": "service_account_email"}'
mock_open.return_value.__enter__.return_value = mock_file

mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_hvac_client.return_value = mock_client
mock_get_scopes.return_value = ["scope1", "scope2"]
mock_get_credentials.return_value = ("credentials", "project_id")

# Mock the current time to use for iat and exp
current_time = int(time.time())
iat = current_time
exp = iat + 3600 # 1 hour after iat

# Mock the signJwt API to return the expected payload
mock_sign_jwt = (
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
)
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"}

vault_client = _VaultClient(
auth_type="gcp",
gcp_key_path="path.json",
gcp_scopes="scope1,scope2",
role_id="role",
url="http://localhost:8180",
session=None,
)
client = vault_client.client
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)

# Preserve the original json.dumps
original_json_dumps = json.dumps

# Inject the mocked payload into the JWT signing process
with mock.patch("json.dumps") as mock_json_dumps:

def mocked_json_dumps(payload):
# Override the payload to inject controlled iat and exp values
payload["iat"] = iat
payload["exp"] = exp
return original_json_dumps(payload) # Use the original json.dumps

mock_json_dumps.side_effect = mocked_json_dumps

client = vault_client.client # Trigger the Vault client creation

# Validate that the HVAC client and other mocks are called correctly
mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None)
mock_get_scopes.assert_called_with("scope1,scope2")
mock_get_credentials.assert_called_with(
key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"]
)
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)
client.auth.gcp.configure.assert_called_with(
credentials="credentials",
)

# Extract the arguments passed to the mocked signJwt API
args, kwargs = mock_sign_jwt.call_args
payload = json.loads(kwargs["body"]["payload"])

# Assert iat and exp values are as expected
assert payload["iat"] == iat
assert payload["exp"] == exp
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat

client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt")
client.is_authenticated.assert_called_with()
assert vault_client.kv_engine_version == 2

@mock.patch("builtins.open", create=True)
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, mock_get_scopes):
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
@mock.patch("googleapiclient.discovery.build")
def test_gcp_different_auth_mount_point(
self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open
):
# Mock the content of the file 'path.json'
mock_file = mock.MagicMock()
mock_file.read.return_value = '{"client_email": "service_account_email"}'
mock_open.return_value.__enter__.return_value = mock_file

mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_hvac_client.return_value = mock_client
mock_get_scopes.return_value = ["scope1", "scope2"]
mock_get_credentials.return_value = ("credentials", "project_id")

mock_sign_jwt = (
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
)
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"}

# Generate realistic iat and exp values
current_time = int(time.time())
iat = current_time
exp = current_time + 3600 # 1 hour later

vault_client = _VaultClient(
auth_type="gcp",
gcp_key_path="path.json",
gcp_scopes="scope1,scope2",
role_id="role",
url="http://localhost:8180",
auth_mount_point="other",
session=None,
)
client = vault_client.client
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)

# Preserve the original json.dumps
original_json_dumps = json.dumps

# Inject the mocked payload into the JWT signing process
with mock.patch("json.dumps") as mock_json_dumps:

def mocked_json_dumps(payload):
# Override the payload to inject controlled iat and exp values
payload["iat"] = iat
payload["exp"] = exp
return original_json_dumps(payload) # Use the original json.dumps

mock_json_dumps.side_effect = mocked_json_dumps

client = vault_client.client # Trigger the Vault client creation

# Assertions
mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None)
mock_get_scopes.assert_called_with("scope1,scope2")
mock_get_credentials.assert_called_with(
key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"]
)
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)
client.auth.gcp.configure.assert_called_with(credentials="credentials", mount_point="other")
# Extract the arguments passed to the mocked signJwt API
args, kwargs = mock_sign_jwt.call_args
payload = json.loads(kwargs["body"]["payload"])

# Assert iat and exp values are as expected
assert payload["iat"] == iat
assert payload["exp"] == exp
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat

client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt", mount_point="other")
client.is_authenticated.assert_called_with()
assert vault_client.kv_engine_version == 2

@mock.patch("builtins.open", create=True)
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes):
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
@mock.patch("googleapiclient.discovery.build")
def test_gcp_dict(
self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open
):
# Mock the content of the file 'path.json'
mock_file = mock.MagicMock()
mock_file.read.return_value = '{"client_email": "service_account_email"}'
mock_open.return_value.__enter__.return_value = mock_file

# Mock the content of the keyfile dict
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_hvac_client.return_value = mock_client
mock_get_scopes.return_value = ["scope1", "scope2"]
mock_get_credentials.return_value = ("credentials", "project_id")

mock_sign_jwt = (
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
)
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"}

# Generate realistic iat and exp values
current_time = int(time.time())
iat = current_time
exp = current_time + 3600 # 1 hour later

vault_client = _VaultClient(
auth_type="gcp",
gcp_keyfile_dict={"key": "value"},
gcp_scopes="scope1,scope2",
role_id="role",
url="http://localhost:8180",
session=None,
)
client = vault_client.client
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)

# Preserve the original json.dumps
original_json_dumps = json.dumps

# Inject the mocked payload into the JWT signing process
with mock.patch("json.dumps") as mock_json_dumps:

def mocked_json_dumps(payload):
# Override the payload to inject controlled iat and exp values
payload["iat"] = iat
payload["exp"] = exp
return original_json_dumps(payload) # Use the original json.dumps

mock_json_dumps.side_effect = mocked_json_dumps

client = vault_client.client # Trigger the Vault client creation

# Assertions
mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None)
mock_get_scopes.assert_called_with("scope1,scope2")
mock_get_credentials.assert_called_with(
key_path=None, keyfile_dict={"key": "value"}, scopes=["scope1", "scope2"]
)
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)
client.auth.gcp.configure.assert_called_with(
credentials="credentials",
)
# Extract the arguments passed to the mocked signJwt API
args, kwargs = mock_sign_jwt.call_args
payload = json.loads(kwargs["body"]["payload"])

# Assert iat and exp values are as expected
assert payload["iat"] == iat
assert payload["exp"] == exp
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat

client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt")
client.is_authenticated.assert_called_with()
assert vault_client.kv_engine_version == 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_credenti
"auth_type": "gcp",
"gcp_key_path": "path.json",
"gcp_scopes": "scope1,scope2",
"role_id": "role",
"session": None,
}

Expand Down Expand Up @@ -481,6 +482,7 @@ def test_gcp_dejson(self, mock_hvac, mock_get_connection, mock_get_credentials,
"auth_type": "gcp",
"gcp_key_path": "path.json",
"gcp_scopes": "scope1,scope2",
"role_id": "role",
}

mock_connection.extra_dejson.get.side_effect = connection_dict.get
Expand Down Expand Up @@ -519,12 +521,14 @@ def test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, mock_get_credenti
"auth_type": "gcp",
"gcp_keyfile_dict": '{"key": "value"}',
"gcp_scopes": "scope1,scope2",
"role_id": "role",
}

mock_connection.extra_dejson.get.side_effect = connection_dict.get
kwargs = {
"vault_conn_id": "vault_conn_id",
"session": None,
"role_id": "role",
}

test_hook = VaultHook(**kwargs)
Expand Down
Loading