Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class _VaultClient(LoggingMixin):
:param assume_role_kwargs: AWS assume role param.
See AWS STS Docs:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html
:param region: AWS region for STS API calls. Inferred from the boto3 client configuration if not provided
(for ``aws_iam`` auth_type).
:param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type).
:param kubernetes_jwt_path: Path for kubernetes jwt token (for ``kubernetes`` auth_type, default:
``/var/run/secrets/kubernetes.io/serviceaccount/token``).
Expand Down Expand Up @@ -108,6 +110,7 @@ def __init__(
secret_id: str | None = None,
assume_role_kwargs: dict | None = None,
role_id: str | None = None,
region: str | None = None,
kubernetes_role: str | None = None,
kubernetes_jwt_path: str | None = "/var/run/secrets/kubernetes.io/serviceaccount/token",
gcp_key_path: str | None = None,
Expand Down Expand Up @@ -166,6 +169,7 @@ def __init__(
self.secret_id = secret_id
self.role_id = role_id
self.assume_role_kwargs = assume_role_kwargs
self.region = region
self.kubernetes_role = kubernetes_role
self.kubernetes_jwt_path = kubernetes_jwt_path
self.gcp_key_path = gcp_key_path
Expand Down Expand Up @@ -329,7 +333,6 @@ def _auth_aws_iam(self, _client: hvac.Client) -> None:
auth_args = {
"access_key": self.key_id,
"secret_key": self.secret_id,
"role": self.role_id,
}
else:
import boto3
Expand All @@ -341,6 +344,7 @@ def _auth_aws_iam(self, _client: hvac.Client) -> None:
"access_key": credentials["Credentials"]["AccessKeyId"],
"secret_key": credentials["Credentials"]["SecretAccessKey"],
"session_token": credentials["Credentials"]["SessionToken"],
"region": sts_client.meta.region_name,
}
else:
session = boto3.Session()
Expand All @@ -349,10 +353,15 @@ def _auth_aws_iam(self, _client: hvac.Client) -> None:
"access_key": credentials.access_key,
"secret_key": credentials.secret_key,
"session_token": credentials.token,
"region": session.region_name,
}

if self.auth_mount_point:
auth_args["mount_point"] = self.auth_mount_point
if self.region:
auth_args["region"] = self.region
if self.role_id:
auth_args["role"] = self.role_id

_client.auth.aws.iam_login(**auth_args)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class VaultHook(BaseHook):
:param kv_engine_version: Select the version of the engine to run (``1`` or ``2``). Defaults to
version defined in connection or ``2`` if not defined in connection.
:param role_id: Role ID for ``aws_iam`` Authentication.
:param region: AWS region for STS API calls (for ``aws_iam`` auth_type).
:param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type)
:param kubernetes_jwt_path: Path for kubernetes jwt token (for ``kubernetes`` auth_type, default:
``/var/run/secrets/kubernetes.io/serviceaccount/token``)
Expand Down Expand Up @@ -113,6 +114,7 @@ def __init__(
auth_mount_point: str | None = None,
kv_engine_version: int | None = None,
role_id: str | None = None,
region: str | None = None,
kubernetes_role: str | None = None,
kubernetes_jwt_path: str | None = None,
token_path: str | None = None,
Expand Down Expand Up @@ -151,6 +153,8 @@ def __init__(
if auth_type == "aws_iam":
if not role_id:
role_id = self.connection.extra_dejson.get("role_id")
if not region:
region = self.connection.extra_dejson.get("region")

azure_resource, azure_tenant_id = (
self._get_azure_parameters_from_connection(azure_resource, azure_tenant_id)
Expand Down Expand Up @@ -210,6 +214,7 @@ def __init__(
key_id=self.connection.login,
secret_id=self.connection.password,
role_id=role_id,
region=region,
kubernetes_role=kubernetes_role,
kubernetes_jwt_path=kubernetes_jwt_path,
gcp_key_path=gcp_key_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin):
:param assume_role_kwargs: AWS assume role param.
See AWS STS Docs:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html
:param region: AWS region for STS API calls (for ``aws_iam`` auth_type).
:param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type).
:param kubernetes_jwt_path: Path for kubernetes jwt token (for ``kubernetes`` auth_type, default:
``/var/run/secrets/kubernetes.io/serviceaccount/token``).
Expand Down Expand Up @@ -108,6 +109,7 @@ def __init__(
secret_id: str | None = None,
role_id: str | None = None,
assume_role_kwargs: dict | None = None,
region: str | None = None,
kubernetes_role: str | None = None,
kubernetes_jwt_path: str = "/var/run/secrets/kubernetes.io/serviceaccount/token",
gcp_key_path: str | None = None,
Expand Down Expand Up @@ -149,6 +151,7 @@ def __init__(
secret_id=secret_id,
role_id=role_id,
assume_role_kwargs=assume_role_kwargs,
region=region,
kubernetes_role=kubernetes_role,
kubernetes_jwt_path=kubernetes_jwt_path,
gcp_key_path=gcp_key_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@ def test_aws_iam_different_auth_mount_point(self, mock_hvac):
client.is_authenticated.assert_called_with()
assert vault_client.kv_engine_version == 2

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_aws_iam_different_region(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
vault_client = _VaultClient(
auth_type="aws_iam",
role_id="role",
url="http://localhost:8180",
key_id="user",
secret_id="pass",
session=None,
region="us-east-2",
)
client = vault_client.client
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)
client.auth.aws.iam_login.assert_called_with(
access_key="user",
secret_key="pass",
role="role",
region="us-east-2",
)
client.is_authenticated.assert_called_with()
assert vault_client.kv_engine_version == 2

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_azure(self, mock_hvac):
mock_client = mock.MagicMock()
Expand Down
15 changes: 7 additions & 8 deletions providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,16 +306,15 @@ def test_aws_iam_init_params(self, mock_hvac, mock_get_connection):
"auth_type": "aws_iam",
"role_id": "role",
"session": None,
"region": "us-east-2",
}

test_hook = VaultHook(**kwargs)
mock_get_connection.assert_called_with("vault_conn_id")
test_client = test_hook.get_conn()
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)
test_client.auth.aws.iam_login.assert_called_with(
access_key="user",
secret_key="pass",
role="role",
access_key="user", secret_key="pass", role="role", region="us-east-2"
)
test_client.is_authenticated.assert_called_with()
assert test_hook.vault_client.kv_engine_version == 2
Expand All @@ -328,7 +327,7 @@ def test_aws_iam_dejson(self, mock_hvac, mock_get_connection):
mock_connection = self.get_mock_connection()
mock_get_connection.return_value = mock_connection

connection_dict = {"auth_type": "aws_iam", "role_id": "role"}
connection_dict = {"auth_type": "aws_iam", "role_id": "role", "region": "us-east-2"}

mock_connection.extra_dejson.get.side_effect = connection_dict.get
kwargs = {
Expand All @@ -344,21 +343,21 @@ def test_aws_iam_dejson(self, mock_hvac, mock_get_connection):
access_key="user",
secret_key="pass",
role="role",
region="us-east-2",
)

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@mock.patch.dict(
"os.environ",
AIRFLOW_CONN_VAULT_CONN_ID="https://login:pass@vault.example.com?auth_type=aws_iam&role_id=role",
AIRFLOW_CONN_VAULT_CONN_ID="https://login:pass@vault.example.com?auth_type=aws_iam&role_id=role"
"&region=us-east-2",
)
def test_aws_uri(self, mock_hvac):
test_hook = VaultHook(vault_conn_id="vault_conn_id", session=None)
test_client = test_hook.get_conn()
mock_hvac.Client.assert_called_with(url="https://vault.example.com", session=None)
test_client.auth.aws.iam_login.assert_called_with(
access_key="login",
secret_key="pass",
role="role",
access_key="login", secret_key="pass", role="role", region="us-east-2"
)
test_client.is_authenticated.assert_called_with()
assert test_hook.vault_client.kv_engine_version == 2
Expand Down