-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add AWS Athena profile mapping #578
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"Athena Airflow connection -> dbt profile mappings" | ||
|
||
from .access_key import AthenaAccessKeyProfileMapping | ||
|
||
__all__ = ["AthenaAccessKeyProfileMapping"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key." | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from ..base import BaseProfileMapping | ||
|
||
|
||
class AthenaAccessKeyProfileMapping(BaseProfileMapping): | ||
""" | ||
Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key. | ||
|
||
https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup | ||
https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/aws.html | ||
""" | ||
|
||
benjamin-awd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
airflow_connection_type: str = "aws" | ||
dbt_profile_type: str = "athena" | ||
is_community: bool = True | ||
|
||
required_fields = [ | ||
"aws_access_key_id", | ||
"aws_secret_access_key", | ||
"database", | ||
"region_name", | ||
"s3_staging_dir", | ||
"schema", | ||
] | ||
secret_fields = [ | ||
"aws_secret_access_key", | ||
] | ||
airflow_param_mapping = { | ||
"aws_access_key_id": "login", | ||
"aws_secret_access_key": "password", | ||
"aws_profile_name": "extra.aws_profile_name", | ||
"database": "extra.database", | ||
"debug_query_state": "extra.debug_query_state", | ||
"lf_tags_database": "extra.lf_tags_database", | ||
"num_retries": "extra.num_retries", | ||
"poll_interval": "extra.poll_interval", | ||
"region_name": "extra.region_name", | ||
"s3_data_dir": "extra.s3_data_dir", | ||
"s3_data_naming": "extra.s3_data_naming", | ||
"s3_staging_dir": "extra.s3_staging_dir", | ||
"schema": "extra.schema", | ||
"seed_s3_upload_args": "extra.seed_s3_upload_args", | ||
"work_group": "extra.work_group", | ||
} | ||
|
||
@property | ||
def profile(self) -> dict[str, Any | None]: | ||
"Gets profile. The password is stored in an environment variable." | ||
profile = { | ||
**self.mapped_params, | ||
**self.profile_args, | ||
# aws_secret_access_key should always get set as env var | ||
"aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"), | ||
} | ||
return self.filter_null(profile) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
"Tests for the Athena profile." | ||
|
||
import json | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from airflow.models.connection import Connection | ||
|
||
from cosmos.profiles import get_automatic_profile_mapping | ||
from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping | ||
|
||
|
||
@pytest.fixture() | ||
def mock_athena_conn(): # type: ignore | ||
""" | ||
Sets the connection as an environment variable. | ||
""" | ||
conn = Connection( | ||
conn_id="my_athena_connection", | ||
conn_type="aws", | ||
login="my_aws_access_key_id", | ||
password="my_aws_secret_key", | ||
extra=json.dumps( | ||
{ | ||
"database": "my_database", | ||
"region_name": "my_region", | ||
"s3_staging_dir": "s3://my_bucket/dbt/", | ||
"schema": "my_schema", | ||
} | ||
), | ||
) | ||
|
||
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): | ||
yield conn | ||
|
||
|
||
def test_athena_connection_claiming() -> None: | ||
""" | ||
Tests that the Athena profile mapping claims the correct connection type. | ||
""" | ||
# should only claim when: | ||
# - conn_type == aws | ||
# and the following exist: | ||
# - login | ||
# - password | ||
# - database | ||
# - region_name | ||
# - s3_staging_dir | ||
# - schema | ||
potential_values = { | ||
"conn_type": "aws", | ||
"login": "my_aws_access_key_id", | ||
"password": "my_aws_secret_key", | ||
"extra": json.dumps( | ||
{ | ||
"database": "my_database", | ||
"region_name": "my_region", | ||
"s3_staging_dir": "s3://my_bucket/dbt/", | ||
"schema": "my_schema", | ||
} | ||
), | ||
} | ||
|
||
# if we're missing any of the values, it shouldn't claim | ||
for key in potential_values: | ||
values = potential_values.copy() | ||
del values[key] | ||
conn = Connection(**values) # type: ignore | ||
|
||
print("testing with", values) | ||
|
||
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): | ||
# should raise an InvalidMappingException | ||
profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) | ||
assert not profile_mapping.can_claim_connection() | ||
|
||
# if we have them all, it should claim | ||
conn = Connection(**potential_values) # type: ignore | ||
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): | ||
profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) | ||
assert profile_mapping.can_claim_connection() | ||
|
||
|
||
def test_athena_profile_mapping_selected( | ||
mock_athena_conn: Connection, | ||
) -> None: | ||
""" | ||
Tests that the correct profile mapping is selected for Athena. | ||
""" | ||
profile_mapping = get_automatic_profile_mapping( | ||
mock_athena_conn.conn_id, | ||
) | ||
assert isinstance(profile_mapping, AthenaAccessKeyProfileMapping) | ||
|
||
|
||
def test_athena_profile_args( | ||
mock_athena_conn: Connection, | ||
) -> None: | ||
""" | ||
Tests that the profile values get set correctly for Athena. | ||
""" | ||
profile_mapping = get_automatic_profile_mapping( | ||
mock_athena_conn.conn_id, | ||
) | ||
|
||
assert profile_mapping.profile == { | ||
"type": "athena", | ||
"aws_access_key_id": mock_athena_conn.login, | ||
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", | ||
"database": mock_athena_conn.extra_dejson.get("database"), | ||
"region_name": mock_athena_conn.extra_dejson.get("region_name"), | ||
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), | ||
"schema": mock_athena_conn.extra_dejson.get("schema"), | ||
} | ||
|
||
|
||
def test_athena_profile_args_overrides( | ||
mock_athena_conn: Connection, | ||
) -> None: | ||
""" | ||
Tests that you can override the profile values for Athena. | ||
""" | ||
profile_mapping = get_automatic_profile_mapping( | ||
mock_athena_conn.conn_id, | ||
profile_args={"schema": "my_custom_schema", "database": "my_custom_db"}, | ||
) | ||
assert profile_mapping.profile_args == { | ||
"schema": "my_custom_schema", | ||
"database": "my_custom_db", | ||
} | ||
|
||
assert profile_mapping.profile == { | ||
"type": "athena", | ||
"aws_access_key_id": mock_athena_conn.login, | ||
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", | ||
"database": "my_custom_db", | ||
"region_name": mock_athena_conn.extra_dejson.get("region_name"), | ||
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), | ||
"schema": "my_custom_schema", | ||
} | ||
|
||
|
||
def test_athena_profile_env_vars( | ||
mock_athena_conn: Connection, | ||
) -> None: | ||
""" | ||
Tests that the environment variables get set correctly for Athena. | ||
""" | ||
profile_mapping = get_automatic_profile_mapping( | ||
mock_athena_conn.conn_id, | ||
) | ||
assert profile_mapping.env_vars == { | ||
"COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password, | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, should this have been included in the all export on line 60?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, good spot.
I have an upcoming PR for the Athena profile mapping, so I'll probably squeeze the fix in there.
This shouldn't affect anyone, unless they're doing
from cosmos.profiles import *