Skip to content

Commit

Permalink
Allow setup endpoint_url per-service in AWS Connection (#34593)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Sep 28, 2023
1 parent d0f2463 commit dd325b4
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 71 deletions.
70 changes: 47 additions & 23 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,11 @@ def _refresh_credentials(self) -> dict[str, Any]:
if assume_role_method not in ("assume_role", "assume_role_with_saml"):
raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

sts_client = self.basic_session.client("sts", config=self.config)
sts_client = self.basic_session.client(
"sts",
config=self.config,
endpoint_url=self.conn.get_service_endpoint_url("sts", sts_connection_assume=True),
)

if assume_role_method == "assume_role":
sts_response = self._assume_role(sts_client=sts_client)
Expand Down Expand Up @@ -558,10 +562,33 @@ def conn_config(self) -> AwsConnectionWrapper:
conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
)

def _resolve_service_name(self, is_resource_type: bool = False) -> str:
"""Resolve service name based on type or raise an error."""
if exactly_one(self.client_type, self.resource_type):
# It is possible to write simple conditions, however it make mypy unhappy.
if self.client_type:
if is_resource_type:
raise LookupError("Requested `resource_type`, but `client_type` was set instead.")
return self.client_type
elif self.resource_type:
if not is_resource_type:
raise LookupError("Requested `client_type`, but `resource_type` was set instead.")
return self.resource_type

raise ValueError(
f"Either client_type={self.client_type!r} or "
f"resource_type={self.resource_type!r} must be provided, not both."
)

@property
def service_name(self) -> str:
"""Extracted botocore/boto3 service name from hook parameters."""
return self._resolve_service_name(is_resource_type=bool(self.resource_type))

@property
def service_config(self) -> dict:
service_name = self.client_type or self.resource_type
return self.conn_config.get_service_config(service_name)
"""Config for hook-specific service from AWS Connection."""
return self.conn_config.get_service_config(service_name=self.service_name)

@property
def region_name(self) -> str | None:
Expand Down Expand Up @@ -609,19 +636,20 @@ def get_client_type(
deferrable: bool = False,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session."""
client_type = self.client_type
service_name = self._resolve_service_name(is_resource_type=False)
session = self.get_session(region_name=region_name, deferrable=deferrable)
endpoint_url = self.conn_config.get_service_endpoint_url(service_name=service_name)
if not isinstance(session, boto3.session.Session):
return session.create_client(
client_type,
endpoint_url=self.conn_config.endpoint_url,
service_name=service_name,
endpoint_url=endpoint_url,
config=self._get_config(config),
verify=self.verify,
)

return session.client(
client_type,
endpoint_url=self.conn_config.endpoint_url,
service_name=service_name,
endpoint_url=endpoint_url,
config=self._get_config(config),
verify=self.verify,
)
Expand All @@ -632,11 +660,11 @@ def get_resource_type(
config: Config | None = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session."""
resource_type = self.resource_type
service_name = self._resolve_service_name(is_resource_type=True)
session = self.get_session(region_name=region_name)
return session.resource(
resource_type,
endpoint_url=self.conn_config.endpoint_url,
service_name=service_name,
endpoint_url=self.conn_config.get_service_endpoint_url(service_name=service_name),
config=self._get_config(config),
verify=self.verify,
)
Expand All @@ -648,15 +676,9 @@ def conn(self) -> BaseAwsConnection:
:return: boto3.client or boto3.resource
"""
if not exactly_one(self.client_type, self.resource_type):
raise ValueError(
f"Either client_type={self.client_type!r} or "
f"resource_type={self.resource_type!r} must be provided, not both."
)
elif self.client_type:
if self.client_type:
return self.get_client_type(region_name=self.region_name)
else:
return self.get_resource_type(region_name=self.region_name)
return self.get_resource_type(region_name=self.region_name)

@property
def async_conn(self):
Expand Down Expand Up @@ -730,7 +752,10 @@ def expand_role(self, role: str, region_name: str | None = None) -> str:
else:
session = self.get_session(region_name=region_name)
_client = session.client(
"iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
service_name="iam",
endpoint_url=self.conn_config.get_service_endpoint_url("iam"),
config=self.config,
verify=self.verify,
)
return _client.get_role(RoleName=role)["Role"]["Arn"]

Expand Down Expand Up @@ -799,10 +824,9 @@ def test_connection(self):
"""
try:
session = self.get_session()
test_endpoint_url = self.conn_config.extra_config.get("test_endpoint_url")
conn_info = session.client(
"sts",
endpoint_url=test_endpoint_url,
service_name="sts",
endpoint_url=self.conn_config.get_service_endpoint_url("sts", sts_test_connection=True),
).get_caller_identity()
metadata = conn_info.pop("ResponseMetadata", {})
if metadata.get("HTTPStatusCode") != 200:
Expand Down
42 changes: 40 additions & 2 deletions airflow/providers/amazon/aws/utils/connection_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,48 @@ class AwsConnectionWrapper(LoggingMixin):
assume_role_method: str | None = field(init=False, default=None)
assume_role_kwargs: dict[str, Any] = field(init=False, default_factory=dict)

# Per AWS Service configuration dictionary where key is name of boto3 ``service_name``
service_config: dict[str, dict[str, Any]] = field(init=False, default_factory=dict)

@cached_property
def conn_repr(self):
return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"

def get_service_config(self, service_name):
return self.extra_dejson.get("service_config", {}).get(service_name, {})
def get_service_config(self, service_name: str) -> dict[str, Any]:
"""Get AWS Service related config dictionary.
:param service_name: Name of botocore/boto3 service.
"""
return self.service_config.get(service_name, {})

def get_service_endpoint_url(
self, service_name: str, *, sts_connection_assume: bool = False, sts_test_connection: bool = False
) -> str | None:
service_config = self.get_service_config(service_name=service_name)
global_endpoint_url = self.endpoint_url

if service_name == "sts" and True in (sts_connection_assume, sts_test_connection):
# There are different logics exists historically for STS Client
# 1. For assume role we never use global endpoint_url
# 2. For test connection we also use undocumented `test_endpoint`\
# 3. For STS as service we might use endpoint_url (default for other services)
global_endpoint_url = None
if sts_connection_assume and sts_test_connection:
raise AirflowException(
"Can't resolve STS endpoint when both "
"`sts_connection` and `sts_test_connection` set to True."
)
elif sts_test_connection:
if "test_endpoint_url" in self.extra_config:
warnings.warn(
"extra['test_endpoint_url'] is deprecated and will be removed in a future release."
" Please set `endpoint_url` in `service_config.sts` within `extras`.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
global_endpoint_url = self.extra_config["test_endpoint_url"]

return service_config.get("endpoint_url", global_endpoint_url)

def __post_init__(self, conn: Connection):
if isinstance(conn, type(self)):
Expand Down Expand Up @@ -182,6 +218,8 @@ def __post_init__(self, conn: Connection):
)

extra = deepcopy(conn.extra_dejson)
self.service_config = extra.get("service_config", {})

session_kwargs = extra.get("session_kwargs", {})
if session_kwargs:
warnings.warn(
Expand Down
36 changes: 35 additions & 1 deletion docs/apache-airflow-providers-amazon/connections/aws.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ Extra (optional)
* ``config_kwargs``: Additional **kwargs** used to construct a
`botocore.config.Config <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html>`__.
To anonymously access public AWS resources (equivalent of `signature_version=botocore.UNSGINED`), set `"signature_version"="unsigned"` within `config_kwargs`.
* ``endpoint_url``: Endpoint URL for the connection.
* ``endpoint_url``: Global Endpoint URL for the connection. You could specify endpoint url per AWS service by utilize
``service_config``, for more details please refer to :ref:`howto/connection:aws:per-service-endpoint-configuration`

* ``verify``: Whether or not to verify SSL certificates.

The following extra parameters used for specific AWS services:
Expand Down Expand Up @@ -343,6 +345,38 @@ The following settings may be used within the ``assume_role_with_saml`` containe
Per-service configuration
^^^^^^^^^^^^^^^^^^^^^^^^^

.. _howto/connection:aws:per-service-endpoint-configuration:

AWS Service Endpoint URL configuration
""""""""""""""""""""""""""""""""""""""

To use ``endpoint_url`` per specific AWS service in the single connection you might setup it in service config.
For enforce to default ``botocore``/``boto3`` behaviour you might set value to ``null``.
The precedence rules are as follows:

1. ``endpoint_url`` specified per service level.
2. ``endpoint_url`` specified in root level of connection extra. Please note that **sts** client which are
uses in assume role or test connection do not use global parameter.
3. Default ``botocore``/``boto3`` behaviour


.. code-block:: json
{
"endpoint_url": "s3.amazonaws.com"
"service_config": {
"s3": {
"endpoint_url": "https://s3.eu-west-1.amazonaws.com"
},
"sts": {
"endpoint_url": "https://sts.eu-west-2.amazonaws.com"
},
"ec2": {
"endpoint_url": null
}
}
}
S3 Bucket configurations
""""""""""""""""""""""""

Expand Down
Loading

0 comments on commit dd325b4

Please sign in to comment.