diff --git a/dbt/adapters/athena/__init__.py b/dbt/adapters/athena/__init__.py
index ebd85a49..c2f140db 100644
--- a/dbt/adapters/athena/__init__.py
+++ b/dbt/adapters/athena/__init__.py
@@ -1,14 +1,11 @@
-import dbt
 from dbt.adapters.athena.connections import AthenaConnectionManager, AthenaCredentials
 from dbt.adapters.athena.impl import AthenaAdapter
-from dbt.adapters.athena.query_headers import _QueryComment
 from dbt.adapters.base import AdapterPlugin
 from dbt.include import athena

-Plugin = AdapterPlugin(adapter=AthenaAdapter, credentials=AthenaCredentials, include_path=athena.PACKAGE_PATH)
-
-# overwrite _QueryComment to add leading "--" to query comment
-dbt.adapters.base.query_headers._QueryComment = _QueryComment
+Plugin: AdapterPlugin = AdapterPlugin(
+    adapter=AthenaAdapter, credentials=AthenaCredentials, include_path=athena.PACKAGE_PATH
+)

 __all__ = [
     "AthenaConnectionManager",
diff --git a/dbt/adapters/athena/__version__.py b/dbt/adapters/athena/__version__.py
index 2196826f..6496f3e2 100644
--- a/dbt/adapters/athena/__version__.py
+++ b/dbt/adapters/athena/__version__.py
@@ -1 +1 @@
-version = "1.7.2"
+version = "1.8.0b1"
diff --git a/dbt/adapters/athena/column.py b/dbt/adapters/athena/column.py
index 46448cae..b198cbfa 100644
--- a/dbt/adapters/athena/column.py
+++ b/dbt/adapters/athena/column.py
@@ -1,9 +1,10 @@
 from dataclasses import dataclass
 from typing import ClassVar, Dict

+from dbt_common.exceptions import DbtRuntimeError
+
 from dbt.adapters.athena.relation import TableType
 from dbt.adapters.base.column import Column
-from dbt.exceptions import DbtRuntimeError


 @dataclass
diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py
index e896c452..63d158a0 100644
--- a/dbt/adapters/athena/connections.py
+++ b/dbt/adapters/athena/connections.py
@@ -9,6 +9,8 @@ from typing import Any, ContextManager, Dict, List, Optional, Tuple

 import tenacity
+from dbt_common.exceptions import ConnectionError, DbtRuntimeError
+from dbt_common.utils import md5
 from pyathena.connection import Connection as AthenaConnection
 from pyathena.cursor import Cursor
 from pyathena.error import OperationalError, ProgrammingError
@@ -29,12 +31,15 @@ from dbt.adapters.athena.config import get_boto3_config
 from dbt.adapters.athena.constants import LOGGER
+from dbt.adapters.athena.query_headers import AthenaMacroQueryStringSetter
 from dbt.adapters.athena.session import get_boto3_session
-from dbt.adapters.base import Credentials
+from dbt.adapters.contracts.connection import (
+    AdapterResponse,
+    Connection,
+    ConnectionState,
+    Credentials,
+)
 from dbt.adapters.sql import SQLConnectionManager
-from dbt.contracts.connection import AdapterResponse, Connection, ConnectionState
-from dbt.exceptions import ConnectionError, DbtRuntimeError
-from dbt.utils import md5


 @dataclass
@@ -201,6 +206,9 @@ def inner() -> AthenaCursor:
 class AthenaConnectionManager(SQLConnectionManager):
     TYPE = "athena"

+    def set_query_header(self, query_header_context: Dict[str, Any]) -> None:
+        self.query_header = AthenaMacroQueryStringSetter(self.profile, query_header_context)
+
     @classmethod
     def data_type_code_to_name(cls, type_code: str) -> str:
         """
diff --git a/dbt/adapters/athena/constants.py b/dbt/adapters/athena/constants.py
index aa1dd5e4..9f132d54 100644
--- a/dbt/adapters/athena/constants.py
+++ b/dbt/adapters/athena/constants.py
@@ -1,4 +1,4 @@
-from dbt.events import AdapterLogger
+from dbt.adapters.events.logging import AdapterLogger

 DEFAULT_THREAD_COUNT = 4
 DEFAULT_RETRY_ATTEMPTS = 3
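Note (illustration, not part of the patch): the `set_query_header` hook added in connections.py above replaces the `_QueryComment` monkey-patch removed from `__init__.py`. A minimal sketch of the new construction, mirroring the unit tests later in this diff; the mocked profile is a stand-in:

    from multiprocessing import get_context
    from unittest import mock

    from dbt.adapters.athena.connections import AthenaConnectionManager

    # dbt-adapters 1.x hands the connection manager an explicit multiprocessing
    # context (see tests/unit/test_connection_manager.py further down).
    cm = AthenaConnectionManager(mock.MagicMock(), get_context("spawn"))

    # set_query_header installs an Athena-specific setter on this manager only,
    # instead of globally patching dbt.adapters.base.query_headers:
    #   cm.set_query_header(query_header_context)
    #   -> cm.query_header = AthenaMacroQueryStringSetter(cm.profile, ...)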
diff --git a/dbt/adapters/athena/exceptions.py b/dbt/adapters/athena/exceptions.py
index fe6ded5d..adf6928f 100644
--- a/dbt/adapters/athena/exceptions.py
+++ b/dbt/adapters/athena/exceptions.py
@@ -1,4 +1,4 @@
-from dbt.exceptions import CompilationError, DbtRuntimeError
+from dbt_common.exceptions import CompilationError, DbtRuntimeError


 class SnapshotMigrationRequired(CompilationError):
diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index ea4377fa..316b7430 100755
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -7,16 +7,18 @@
 from dataclasses import dataclass
 from datetime import date, datetime
 from functools import lru_cache
-from itertools import chain
 from textwrap import dedent
 from threading import Lock
-from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type
+from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type
 from urllib.parse import urlparse
 from uuid import uuid4

 import agate
 import mmh3
 from botocore.exceptions import ClientError
+from dbt_common.clients.agate_helper import table_from_rows
+from dbt_common.contracts.constraints import ConstraintType
+from dbt_common.exceptions import DbtRuntimeError
 from mypy_boto3_athena import AthenaClient
 from mypy_boto3_athena.type_defs import DataCatalogTypeDef, GetWorkGroupOutputTypeDef
 from mypy_boto3_glue.type_defs import (
@@ -65,13 +67,9 @@ from dbt.adapters.base import ConstraintSupport, PythonJobHelper, available
 from dbt.adapters.base.impl import AdapterConfig
 from dbt.adapters.base.relation import BaseRelation, InformationSchema
+from dbt.adapters.contracts.connection import AdapterResponse
+from dbt.adapters.contracts.relation import RelationConfig
 from dbt.adapters.sql import SQLAdapter
-from dbt.clients.agate_helper import table_from_rows
-from dbt.config.runtime import RuntimeConfig
-from dbt.contracts.connection import AdapterResponse
-from dbt.contracts.graph.manifest import Manifest
-from dbt.contracts.graph.nodes import CompiledNode, ConstraintType
-from dbt.exceptions import DbtRuntimeError

 boto3_client_lock = Lock()
@@ -538,29 +536,6 @@ def _s3_path_exists(self, s3_bucket: str, s3_prefix: str) -> bool:
         response = s3_client.list_objects_v2(Bucket=s3_bucket, Prefix=s3_prefix)
         return True if "Contents" in response else False

-    def _join_catalog_table_owners(self, table: agate.Table, manifest: Manifest) -> agate.Table:
-        owners = []
-        # Get the owner for each model from the manifest
-        for node in manifest.nodes.values():
-            if node.resource_type == "model":
-                owners.append(
-                    {
-                        "table_database": node.database,
-                        "table_schema": node.schema,
-                        "table_name": node.alias,
-                        "table_owner": node.config.meta.get("owner"),
-                    }
-                )
-        owners_table = agate.Table.from_object(owners)
-
-        # Join owners with the results from catalog
-        join_keys = ["table_database", "table_schema", "table_name"]
-        return table.join(
-            right_table=owners_table,
-            left_key=join_keys,
-            right_key=join_keys,
-        )
-
     def _get_one_table_for_catalog(self, table: TableTypeDef, database: str) -> List[Dict[str, Any]]:
         table_catalog = {
             "table_database": database,
@@ -608,13 +583,13 @@ def _get_one_table_for_non_glue_catalog(
     def _get_one_catalog(
         self,
         information_schema: InformationSchema,
-        schemas: Dict[str, Optional[Set[str]]],
-        manifest: Manifest,
+        schemas: Set[str],
+        used_schemas: FrozenSet[Tuple[str, str]],
     ) -> agate.Table:
         """
         This function is invoked by Adapter.get_catalog for each schema.
""" - data_catalog = self._get_data_catalog(information_schema.path.database) + data_catalog = self._get_data_catalog(information_schema.database) data_catalog_type = get_catalog_type(data_catalog) conn = self.connections.get_thread_connection() @@ -630,7 +605,7 @@ def _get_one_catalog( catalog = [] paginator = glue_client.get_paginator("get_tables") - for schema, relations in schemas.items(): + for schema in schemas: kwargs = { "DatabaseName": schema, "MaxResults": 100, @@ -643,8 +618,7 @@ def _get_one_catalog( for page in paginator.paginate(**kwargs): for table in page["TableList"]: - if relations and table["Name"] in relations: - catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database)) + catalog.extend(self._get_one_table_for_catalog(table, information_schema.database)) table = agate.Table.from_object(catalog) else: with boto3_client_lock: @@ -656,36 +630,28 @@ def _get_one_catalog( catalog = [] paginator = athena_client.get_paginator("list_table_metadata") - for schema, relations in schemas.items(): + for schema in schemas: for page in paginator.paginate( - CatalogName=information_schema.path.database, + CatalogName=information_schema.database, DatabaseName=schema, MaxResults=50, # Limit supported by this operation ): for table in page["TableMetadataList"]: - if relations and table["Name"].lower() in relations: - catalog.extend( - self._get_one_table_for_non_glue_catalog( - table, schema, information_schema.path.database - ) - ) + catalog.extend( + self._get_one_table_for_non_glue_catalog(table, schema, information_schema.database) + ) table = agate.Table.from_object(catalog) - filtered_table = self._catalog_filter_table(table, manifest) - return self._join_catalog_table_owners(filtered_table, manifest) + return self._catalog_filter_table(table, used_schemas) - def _get_catalog_schemas(self, manifest: Manifest) -> AthenaSchemaSearchMap: + def _get_catalog_schemas(self, relation_configs: Iterable[RelationConfig]) -> AthenaSchemaSearchMap: """ Get the schemas from the catalog. It's called by the `get_catalog` method. """ info_schema_name_map = AthenaSchemaSearchMap() - nodes: Iterator[CompiledNode] = chain( - [node for node in manifest.nodes.values() if (node.is_relational and not node.is_ephemeral_model)], - manifest.sources.values(), - ) - for node in nodes: - relation = self.Relation.create_from(self.config, node) + for relation_config in relation_configs: + relation = self.Relation.create_from(quoting=self.config, relation_config=relation_config) info_schema_name_map.add(relation) return info_schema_name_map @@ -775,9 +741,9 @@ def list_relations_without_caching(self, schema_relation: AthenaRelation) -> Lis def _get_one_catalog_by_relations( self, information_schema: InformationSchema, - relations: List[BaseRelation], - manifest: Manifest, - ) -> agate.Table: + relations: List[AthenaRelation], + used_schemas: FrozenSet[Tuple[str, str]], + ) -> "agate.Table": """ Overwrite of _get_one_catalog_by_relations for Athena, in order to use glue apis. This function is invoked by Adapter.get_catalog_by_relations. 
diff --git a/dbt/adapters/athena/lakeformation.py b/dbt/adapters/athena/lakeformation.py
index 227fb751..86b51d01 100644
--- a/dbt/adapters/athena/lakeformation.py
+++ b/dbt/adapters/athena/lakeformation.py
@@ -2,6 +2,7 @@

 from typing import Dict, List, Optional, Sequence, Set, Union

+from dbt_common.exceptions import DbtRuntimeError
 from mypy_boto3_lakeformation import LakeFormationClient
 from mypy_boto3_lakeformation.type_defs import (
     AddLFTagsToResourceResponseTypeDef,
@@ -16,8 +17,7 @@
 from pydantic import BaseModel

 from dbt.adapters.athena.relation import AthenaRelation
-from dbt.events import AdapterLogger
-from dbt.exceptions import DbtRuntimeError
+from dbt.adapters.events.logging import AdapterLogger

 logger = AdapterLogger("AthenaLakeFormation")
diff --git a/dbt/adapters/athena/python_submissions.py b/dbt/adapters/athena/python_submissions.py
index 4fe8fa0f..5a3799ec 100644
--- a/dbt/adapters/athena/python_submissions.py
+++ b/dbt/adapters/athena/python_submissions.py
@@ -3,13 +3,13 @@
 from typing import Any, Dict

 import botocore
+from dbt_common.exceptions import DbtRuntimeError

 from dbt.adapters.athena.config import AthenaSparkSessionConfig
 from dbt.adapters.athena.connections import AthenaCredentials
 from dbt.adapters.athena.constants import LOGGER
 from dbt.adapters.athena.session import AthenaSparkSessionManager
 from dbt.adapters.base import PythonJobHelper
-from dbt.exceptions import DbtRuntimeError

 SUBMISSION_LANGUAGE = "python"
diff --git a/dbt/adapters/athena/query_headers.py b/dbt/adapters/athena/query_headers.py
index 4fb4bff2..ab3d1a61 100644
--- a/dbt/adapters/athena/query_headers.py
+++ b/dbt/adapters/athena/query_headers.py
@@ -1,7 +1,16 @@
-import dbt.adapters.base.query_headers
+from typing import Any, Dict

+from dbt.adapters.base.query_headers import MacroQueryStringSetter, _QueryComment
+from dbt.adapters.contracts.connection import AdapterRequiredConfig

-class _QueryComment(dbt.adapters.base.query_headers._QueryComment):
+
+class AthenaMacroQueryStringSetter(MacroQueryStringSetter):
+    def __init__(self, config: AdapterRequiredConfig, query_header_context: Dict[str, Any]):
+        super().__init__(config, query_header_context)
+        self.comment = _AthenaQueryComment(None)
+
+
+class _AthenaQueryComment(_QueryComment):
     """
     Athena DDL does not always respect /* ... */ block quotations.
     This function is the same as _QueryComment.add except that
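Note (illustration, not part of the patch): a rough sketch of why `_AthenaQueryComment` exists. Per the docstring above, Athena DDL does not always tolerate dbt's default `/* ... */` block comment, so the Athena subclass is assumed to render the comment as a `--` line instead; the exact placement follows dbt's append setting (see tests/unit/test_query_headers.py further down):

    # Default dbt rendering (can break Athena DDL):
    #   /* executed by dbt */
    #   alter table my_table add columns (new_col int)
    #
    # Rendering this adapter aims for:
    #   alter table my_table add columns (new_col int)
    #   -- executed by dbt
    sql, comment = "select 1", "executed by dbt"
    print(f"{sql}\n-- {comment}")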
diff --git a/dbt/adapters/athena/session.py b/dbt/adapters/athena/session.py
index 39594938..b346d13e 100644
--- a/dbt/adapters/athena/session.py
+++ b/dbt/adapters/athena/session.py
@@ -8,6 +8,8 @@

 import boto3
 import boto3.session
+from dbt_common.exceptions import DbtRuntimeError
+from dbt_common.invocation import get_invocation_id

 from dbt.adapters.athena.config import get_boto3_config
 from dbt.adapters.athena.constants import (
@@ -15,9 +17,7 @@
     LOGGER,
     SESSION_IDLE_TIMEOUT_MIN,
 )
-from dbt.contracts.connection import Connection
-from dbt.events.functions import get_invocation_id
-from dbt.exceptions import DbtRuntimeError
+from dbt.adapters.contracts.connection import Connection

 invocation_id = get_invocation_id()
 spark_session_list: Dict[UUID, str] = {}
diff --git a/dbt/include/athena/macros/utils/safe_cast.sql b/dbt/include/athena/macros/utils/safe_cast.sql
index 69cb658f..aed0866e 100644
--- a/dbt/include/athena/macros/utils/safe_cast.sql
+++ b/dbt/include/athena/macros/utils/safe_cast.sql
@@ -1,3 +1,4 @@
+-- TODO: make safe_cast support complex structures
 {% macro athena__safe_cast(field, type) -%}
     try_cast({{field}} as {{type}})
 {%- endmacro %}
diff --git a/dev-requirements.txt b/dev-requirements.txt
index e97d3324..d077f94b 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,7 +1,7 @@
 autoflake~=2.3
 black~=24.4
 boto3-stubs[s3]~=1.34
-dbt-tests-adapter~=1.7.11
+dbt-tests-adapter~=1.8.0b1
 flake8~=7.0
 Flake8-pyproject~=1.2
 isort~=5.13
diff --git a/setup.py b/setup.py
index e009e198..afe69bc3 100644
--- a/setup.py
+++ b/setup.py
@@ -31,17 +31,12 @@ def _get_package_version() -> str:
     return f'{parts["major"]}.{parts["minor"]}.{parts["patch"]}'


-dbt_version = "1.7"
-package_version = _get_package_version()
 description = "The athena adapter plugin for dbt (data build tool)"

-if not package_version.startswith(dbt_version):
-    raise ValueError(f"Invalid setup.py: package_version={package_version} must start with dbt_version={dbt_version}")
-
 setup(
     name=package_name,
-    version=package_version,
+    version=_get_package_version(),
     description=description,
     long_description=long_description,
     long_description_content_type="text/markdown",
@@ -52,12 +47,12 @@
     packages=find_namespace_packages(include=["dbt", "dbt.*"]),
     include_package_data=True,
     install_requires=[
-        # In order to control dbt-core version and package version
+        "dbt-common>=1.0.0b2,<2.0",
+        "dbt-adapters>=1.0.0b2,<2.0",
         "boto3>=1.28",
         "boto3-stubs[athena,glue,lakeformation,sts]>=1.28",
-        "dbt-core~=1.7.0",
-        "mmh3>=4.0.1,<4.2.0",
         "pyathena>=2.25,<4.0",
+        "mmh3>=4.0.1,<4.2.0",
         "pydantic>=1.10,<3.0",
         "tenacity~=8.2",
     ],
@@ -73,5 +68,5 @@
         "Programming Language :: Python :: 3.11",
         "Programming Language :: Python :: 3.12",
     ],
-    python_requires=">=3.8",
+    python_requires=">=3.8,<3.13",
 )
diff --git a/tests/conftest.py b/tests/conftest.py
index 3701479c..e94bd8cf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,13 +4,12 @@

 import boto3
 import pytest
+from dbt_common.events import get_event_manager
+from dbt_common.events.base_types import EventLevel
+from dbt_common.events.logger import LineFormat, LoggerConfig, NoFilter

-import dbt
+from dbt.adapters.athena import connections
 from dbt.adapters.athena.connections import AthenaCredentials
-from dbt.events.base_types import EventLevel
-from dbt.events.eventmgr import LineFormat
-from dbt.events.functions import EVENT_MANAGER, _get_stdout_config
-from dbt.events.logger import NoFilter

 from .unit.constants import (
     ATHENA_WORKGROUP,
@@ -57,17 +56,17 @@ def dbt_debug_caplog() -> StringIO:


 def _setup_custom_caplog(name: str, level: EventLevel):
-    capture_config = _get_stdout_config(
-        line_format=LineFormat.PlainText,
+    string_buf = StringIO()
+    capture_config = LoggerConfig(
+        name=name,
         level=level,
         use_colors=False,
-        log_cache_events=True,
+        line_format=LineFormat.PlainText,
+        filter=NoFilter,
+        output_stream=string_buf,
     )
-    capture_config.name = name
-    capture_config.filter = NoFilter
-    string_buf = StringIO()
-    capture_config.output_stream = string_buf
-    EVENT_MANAGER.add_logger(capture_config)
+    event_manager = get_event_manager()
+    event_manager.add_logger(capture_config)
     return string_buf
@@ -77,7 +76,7 @@ def athena_client():
     return mock_athena_client


-@patch.object(dbt.adapters.athena.connections, "AthenaCredentials")
+@patch.object(connections, "AthenaCredentials")
 @pytest.fixture(scope="class")
 def athena_credentials():
     return AthenaCredentials(
diff --git a/tests/functional/adapter/test_empty.py b/tests/functional/adapter/test_empty.py
new file mode 100644
index 00000000..b9523881
--- /dev/null
+++ b/tests/functional/adapter/test_empty.py
@@ -0,0 +1,5 @@
+from dbt.tests.adapter.empty.test_empty import BaseTestEmpty
+
+
+class TestAthenaEmpty(BaseTestEmpty):
+    pass
diff --git a/tests/functional/adapter/test_unit_testing.py b/tests/functional/adapter/test_unit_testing.py
new file mode 100644
index 00000000..5ec246c2
--- /dev/null
+++ b/tests/functional/adapter/test_unit_testing.py
@@ -0,0 +1,36 @@
+import pytest
+
+from dbt.tests.adapter.unit_testing.test_case_insensitivity import (
+    BaseUnitTestCaseInsensivity,
+)
+from dbt.tests.adapter.unit_testing.test_invalid_input import BaseUnitTestInvalidInput
+from dbt.tests.adapter.unit_testing.test_types import BaseUnitTestingTypes
+
+
+class TestAthenaUnitTestingTypes(BaseUnitTestingTypes):
+    @pytest.fixture
+    def data_types(self):
+        # sql_value, yaml_value
+        return [
+            ["1", "1"],
+            ["2.0", "2.0"],
+            ["'12345'", "12345"],
+            ["'string'", "string"],
+            ["true", "true"],
+            ["date '2024-04-01'", "2024-04-01"],
+            ["timestamp '2024-04-01 00:00:00.000'", "'2024-04-01 00:00:00.000'"],
+            # TODO: activate once safe_cast supports complex structures
+            # ["array[1, 2, 3]", "[1, 2, 3]"],
+            # [
+            #     "map(array['10', '15', '20'], array['t', 'f', NULL])",
+            #     """'{"10": "t", "15": "f", "20": null}'""",
+            # ],
+        ]
+
+
+class TestAthenaUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity):
+    pass
+
+
+class TestAthenaUnitTestInvalidInput(BaseUnitTestInvalidInput):
+    pass
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index ed62c936..6640ccc9 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -1,5 +1,6 @@
 import datetime
 import decimal
+from multiprocessing import get_context
 from unittest import mock
 from unittest.mock import patch
@@ -7,6 +8,8 @@
 import boto3
 import botocore
 import pytest
+from dbt_common.clients import agate_helper
+from dbt_common.exceptions import ConnectionError, DbtRuntimeError
 from moto import mock_aws
 from moto.core import DEFAULT_ACCOUNT_ID
@@ -17,13 +20,8 @@
 from dbt.adapters.athena.exceptions import S3LocationException
 from dbt.adapters.athena.relation import AthenaRelation, TableType
 from dbt.adapters.athena.utils import AthenaCatalogType
-from dbt.clients import agate_helper
-from dbt.contracts.connection import ConnectionState
-from dbt.contracts.files import FileHash
-from dbt.contracts.graph.nodes import CompiledNode, DependsOn, NodeConfig
-from dbt.contracts.relation import RelationType
-from dbt.exceptions import ConnectionError, DbtRuntimeError
-from dbt.node_types import NodeType
+from dbt.adapters.contracts.connection import ConnectionState
+from dbt.adapters.contracts.relation import RelationType

 from .constants import (
     ATHENA_WORKGROUP,
@@ -64,201 +62,20 @@ def setup_method(self, _):
         self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
         self._adapter = None

-        self.mock_manifest = mock.MagicMock()
-        self.mock_manifest.get_used_schemas.return_value = {
-            ("awsdatacatalog", "foo"),
-            ("awsdatacatalog", "quux"),
-            ("awsdatacatalog", "baz"),
-            (SHARED_DATA_CATALOG_NAME, "foo"),
-            (FEDERATED_QUERY_CATALOG_NAME, "foo"),
-        }
-        self.mock_manifest.nodes = {
-            "model.root.model1": CompiledNode(
-                name="model1",
-                database="awsdatacatalog",
-                schema="foo",
-                resource_type=NodeType.Model,
-                unique_id="model.root.model1",
-                alias="bar",
-                fqn=["root", "model1"],
-                package_name="root",
-                refs=[],
-                sources=[],
-                depends_on=DependsOn(),
-                config=NodeConfig.from_dict(
-                    {
-                        "enabled": True,
-                        "materialized": "table",
-                        "persist_docs": {},
-                        "post-hook": [],
-                        "pre-hook": [],
-                        "vars": {},
-                        "meta": {"owner": "data-engineers"},
-                        "quoting": {},
-                        "column_types": {},
-                        "tags": [],
-                    }
-                ),
-                tags=[],
-                path="model1.sql",
-                original_file_path="model1.sql",
-                compiled=True,
-                extra_ctes_injected=False,
-                extra_ctes=[],
-                checksum=FileHash.from_contents(""),
-                raw_code="select * from source_table",
-                language="",
-            ),
-            "model.root.model2": CompiledNode(
-                name="model2",
-                database="awsdatacatalog",
-                schema="quux",
-                resource_type=NodeType.Model,
-                unique_id="model.root.model2",
-                alias="bar",
-                fqn=["root", "model2"],
-                package_name="root",
-                refs=[],
-                sources=[],
-                depends_on=DependsOn(),
-                config=NodeConfig.from_dict(
-                    {
-                        "enabled": True,
-                        "materialized": "table",
-                        "persist_docs": {},
-                        "post-hook": [],
-                        "pre-hook": [],
-                        "vars": {},
-                        "meta": {"owner": "data-analysts"},
-                        "quoting": {},
-                        "column_types": {},
-                        "tags": [],
-                    }
-                ),
-                tags=[],
-                path="model2.sql",
-                original_file_path="model2.sql",
-                compiled=True,
-                extra_ctes_injected=False,
-                extra_ctes=[],
-                checksum=FileHash.from_contents(""),
-                raw_code="select * from source_table",
-                language="",
-            ),
-            "model.root.model3": CompiledNode(
-                name="model2",
-                database="awsdatacatalog",
-                schema="baz",
-                resource_type=NodeType.Model,
-                unique_id="model.root.model3",
-                alias="qux",
-                fqn=["root", "model2"],
-                package_name="root",
-                refs=[],
-                sources=[],
-                depends_on=DependsOn(),
-                config=NodeConfig.from_dict(
-                    {
-                        "enabled": True,
-                        "materialized": "table",
-                        "persist_docs": {},
-                        "post-hook": [],
-                        "pre-hook": [],
-                        "vars": {},
-                        "meta": {"owner": "data-engineers"},
-                        "quoting": {},
-                        "column_types": {},
-                        "tags": [],
-                    }
-                ),
-                tags=[],
-                path="model3.sql",
-                original_file_path="model3.sql",
-                compiled=True,
-                extra_ctes_injected=False,
-                extra_ctes=[],
-                checksum=FileHash.from_contents(""),
-                raw_code="select * from source_table",
-                language="",
-            ),
-            "model.root.model4": CompiledNode(
-                name="model4",
-                database=SHARED_DATA_CATALOG_NAME,
-                schema="foo",
-                resource_type=NodeType.Model,
-                unique_id="model.root.model4",
-                alias="bar",
-                fqn=["root", "model4"],
-                package_name="root",
-                refs=[],
-                sources=[],
-                depends_on=DependsOn(),
-                config=NodeConfig.from_dict(
-                    {
-                        "enabled": True,
-                        "materialized": "table",
"persist_docs": {}, - "post-hook": [], - "pre-hook": [], - "vars": {}, - "meta": {"owner": "data-engineers"}, - "quoting": {}, - "column_types": {}, - "tags": [], - } - ), - tags=[], - path="model4.sql", - original_file_path="model4.sql", - compiled=True, - extra_ctes_injected=False, - extra_ctes=[], - checksum=FileHash.from_contents(""), - raw_code="select * from source_table", - language="", - ), - "model.root.model5": CompiledNode( - name="model5", - database=FEDERATED_QUERY_CATALOG_NAME, - schema="foo", - resource_type=NodeType.Model, - unique_id="model.root.model5", - alias="bar", - fqn=["root", "model5"], - package_name="root", - refs=[], - sources=[], - depends_on=DependsOn(), - config=NodeConfig.from_dict( - { - "enabled": True, - "materialized": "table", - "persist_docs": {}, - "post-hook": [], - "pre-hook": [], - "vars": {}, - "meta": {"owner": "data-engineers"}, - "quoting": {}, - "column_types": {}, - "tags": [], - } - ), - tags=[], - path="model5.sql", - original_file_path="model5.sql", - compiled=True, - extra_ctes_injected=False, - extra_ctes=[], - checksum=FileHash.from_contents(""), - raw_code="select * from source_table", - language="", - ), - } + self.used_schemas = frozenset( + { + ("awsdatacatalog", "foo"), + ("awsdatacatalog", "quux"), + ("awsdatacatalog", "baz"), + (SHARED_DATA_CATALOG_NAME, "foo"), + (FEDERATED_QUERY_CATALOG_NAME, "foo"), + } + ) @property def adapter(self): if self._adapter is None: - self._adapter = AthenaAdapter(self.config) + self._adapter = AthenaAdapter(self.config, get_context("spawn")) inject_adapter(self._adapter, AthenaPlugin) return self._adapter @@ -599,17 +416,13 @@ def test__get_one_catalog(self, mock_aws_service): mock_aws_service.create_table(table_name="bar", database_name="quux") mock_aws_service.create_table_without_type(table_name="qux", database_name="baz") mock_information_schema = mock.MagicMock() - mock_information_schema.path.database = "awsdatacatalog" + mock_information_schema.database = "awsdatacatalog" self.adapter.acquire_connection("dummy") actual = self.adapter._get_one_catalog( mock_information_schema, - { - "foo": {"bar"}, - "quux": {"bar"}, - "baz": {"qux"}, - }, - self.mock_manifest, + {"foo", "quux", "baz"}, + self.used_schemas, ) expected_column_names = ( @@ -622,17 +435,16 @@ def test__get_one_catalog(self, mock_aws_service): "column_index", "column_type", "column_comment", - "table_owner", ) expected_rows = [ - ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"), - ("awsdatacatalog", "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"), - ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"), - ("awsdatacatalog", "quux", "bar", "table", None, "id", 0, "string", None, "data-analysts"), - ("awsdatacatalog", "quux", "bar", "table", None, "country", 1, "string", None, "data-analysts"), - ("awsdatacatalog", "quux", "bar", "table", None, "dt", 2, "date", None, "data-analysts"), - ("awsdatacatalog", "baz", "qux", "table", None, "id", 0, "string", None, "data-engineers"), - ("awsdatacatalog", "baz", "qux", "table", None, "country", 1, "string", None, "data-engineers"), + ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "country", 1, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None), + ("awsdatacatalog", "quux", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "quux", "bar", "table", None, 
"country", 1, "string", None), + ("awsdatacatalog", "quux", "bar", "table", None, "dt", 2, "date", None), + ("awsdatacatalog", "baz", "qux", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "baz", "qux", "table", None, "country", 1, "string", None), ] assert actual.column_names == expected_column_names assert len(actual.rows) == len(expected_rows) @@ -649,7 +461,7 @@ def test__get_one_catalog_by_relations(self, mock_aws_service): mock_aws_service.create_table(table_name="bar", database_name="quux") mock_information_schema = mock.MagicMock() - mock_information_schema.path.database = "awsdatacatalog" + mock_information_schema.database = "awsdatacatalog" self.adapter.acquire_connection("dummy") @@ -669,16 +481,15 @@ def test__get_one_catalog_by_relations(self, mock_aws_service): "column_index", "column_type", "column_comment", - "table_owner", ) expected_rows = [ - ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"), - ("awsdatacatalog", "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"), - ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"), + ("awsdatacatalog", "foo", "bar", "table", None, "id", 0, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "country", 1, "string", None), + ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None), ] - actual = self.adapter._get_one_catalog_by_relations(mock_information_schema, [rel_1], self.mock_manifest) + actual = self.adapter._get_one_catalog_by_relations(mock_information_schema, [rel_1], self.used_schemas) assert actual.column_names == expected_column_names assert actual.rows == expected_rows @@ -688,15 +499,13 @@ def test__get_one_catalog_shared_catalog(self, mock_aws_service): mock_aws_service.create_database("foo", catalog_id=SHARED_DATA_CATALOG_NAME) mock_aws_service.create_table(table_name="bar", database_name="foo", catalog_id=SHARED_DATA_CATALOG_NAME) mock_information_schema = mock.MagicMock() - mock_information_schema.path.database = SHARED_DATA_CATALOG_NAME + mock_information_schema.database = SHARED_DATA_CATALOG_NAME self.adapter.acquire_connection("dummy") actual = self.adapter._get_one_catalog( mock_information_schema, - { - "foo": {"bar"}, - }, - self.mock_manifest, + {"foo"}, + self.used_schemas, ) expected_column_names = ( @@ -709,12 +518,11 @@ def test__get_one_catalog_shared_catalog(self, mock_aws_service): "column_index", "column_type", "column_comment", - "table_owner", ) expected_rows = [ - ("9876543210", "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"), - ("9876543210", "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"), - ("9876543210", "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"), + ("9876543210", "foo", "bar", "table", None, "id", 0, "string", None), + ("9876543210", "foo", "bar", "table", None, "country", 1, "string", None), + ("9876543210", "foo", "bar", "table", None, "dt", 2, "date", None), ] assert actual.column_names == expected_column_names @@ -728,7 +536,7 @@ def test__get_one_catalog_federated_query_catalog(self, mock_aws_service): catalog_name=FEDERATED_QUERY_CATALOG_NAME, catalog_type=AthenaCatalogType.LAMBDA ) mock_information_schema = mock.MagicMock() - mock_information_schema.path.database = FEDERATED_QUERY_CATALOG_NAME + mock_information_schema.database = FEDERATED_QUERY_CATALOG_NAME # Original botocore _make_api_call function orig = botocore.client.BaseClient._make_api_call @@ -768,10 +576,8 
@@ -768,10 +576,8 @@ def mock_athena_list_table_metadata(self, operation_name, kwarg):
         with patch("botocore.client.BaseClient._make_api_call", new=mock_athena_list_table_metadata):
             actual = self.adapter._get_one_catalog(
                 mock_information_schema,
-                {
-                    "foo": {"bar"},
-                },
-                self.mock_manifest,
+                {"foo"},
+                self.used_schemas,
             )

         expected_column_names = (
@@ -784,12 +590,11 @@
             "column_index",
             "column_type",
             "column_comment",
-            "table_owner",
         )
         expected_rows = [
-            (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"),
-            (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"),
-            (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"),
+            (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None),
+            (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None),
+            (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None),
         ]

         assert actual.column_names == expected_column_names
@@ -797,34 +602,6 @@ def mock_athena_list_table_metadata(self, operation_name, kwarg):
         for row in actual.rows.values():
             assert row.values() in expected_rows

-    def test__get_catalog_schemas(self):
-        res = self.adapter._get_catalog_schemas(self.mock_manifest)
-        assert len(res.keys()) == 3
-
-        information_schema_0 = list(res.keys())[0]
-        assert information_schema_0.name == "INFORMATION_SCHEMA"
-        assert information_schema_0.schema is None
-        assert information_schema_0.database == "awsdatacatalog"
-        relations = list(res.values())[0]
-        assert set(relations.keys()) == {"foo", "quux", "baz"}
-        assert list(relations.values()) == [{"bar"}, {"bar"}, {"qux"}]
-
-        information_schema_1 = list(res.keys())[1]
-        assert information_schema_1.name == "INFORMATION_SCHEMA"
-        assert information_schema_1.schema is None
-        assert information_schema_1.database == SHARED_DATA_CATALOG_NAME
-        relations = list(res.values())[1]
-        assert set(relations.keys()) == {"foo"}
-        assert list(relations.values()) == [{"bar"}]
-
-        information_schema_1 = list(res.keys())[2]
-        assert information_schema_1.name == "INFORMATION_SCHEMA"
-        assert information_schema_1.schema is None
-        assert information_schema_1.database == FEDERATED_QUERY_CATALOG_NAME
-        relations = list(res.values())[1]
-        assert set(relations.keys()) == {"foo"}
-        assert list(relations.values()) == [{"bar"}]
-
     @mock_aws
     def test__get_data_catalog(self, mock_aws_service):
         mock_aws_service.create_data_catalog()
@@ -1447,8 +1224,6 @@ def test_format_unsupported_type(self):

 class TestAthenaFilterCatalog:
     def test__catalog_filter_table(self):
-        manifest = mock.MagicMock()
-        manifest.get_used_schemas.return_value = [["a", "B"], ["a", "1234"]]
         column_names = ["table_name", "table_database", "table_schema", "something"]
         rows = [
             ["foo", "a", "b", "1234"],  # include
@@ -1458,7 +1233,7 @@
         ]
         table = agate.Table(rows, column_names, agate_helper.DEFAULT_TYPE_TESTER)

-        result = AthenaAdapter._catalog_filter_table(table, manifest)
+        result = AthenaAdapter._catalog_filter_table(table, frozenset({("a", "B"), ("a", "1234")}))
         assert len(result) == 3
         for row in result.rows:
             assert isinstance(row["table_schema"], str)
diff --git a/tests/unit/test_connection_manager.py b/tests/unit/test_connection_manager.py
index c37a4792..fd48a62f 100644
--- a/tests/unit/test_connection_manager.py
+++ b/tests/unit/test_connection_manager.py
@@ -1,3 +1,4 @@
+from multiprocessing import get_context
 from unittest import mock

 import pytest
@@ -20,7 +21,7 @@ def test_get_response(self, state, result):
         cursor.rowcount = 1
         cursor.state = state
         cursor.data_scanned_in_bytes = 123
-        cm = AthenaConnectionManager(mock.MagicMock())
+        cm = AthenaConnectionManager(mock.MagicMock(), get_context("spawn"))
         response = cm.get_response(cursor)
         assert isinstance(response, AthenaAdapterResponse)
         assert response.code == result
@@ -28,7 +29,7 @@
         assert response.data_scanned_in_bytes == 123

     def test_data_type_code_to_name(self):
-        cm = AthenaConnectionManager(mock.MagicMock())
+        cm = AthenaConnectionManager(mock.MagicMock(), get_context("spawn"))
         assert cm.data_type_code_to_name("array") == "ARRAY"
         assert cm.data_type_code_to_name("map") == "MAP"
         assert cm.data_type_code_to_name("DECIMAL(3, 7)") == "DECIMAL"
diff --git a/tests/unit/test_query_headers.py b/tests/unit/test_query_headers.py
index 2043dd17..f361fa2e 100644
--- a/tests/unit/test_query_headers.py
+++ b/tests/unit/test_query_headers.py
@@ -1,15 +1,15 @@
 from unittest import mock

-from dbt.adapters.athena import _QueryComment
-from dbt.adapters.base.query_headers import MacroQueryStringSetter
+from dbt.adapters.athena.query_headers import AthenaMacroQueryStringSetter
+from dbt.context.manifest import generate_query_header_context

 from .constants import AWS_REGION, DATA_CATALOG_NAME, DATABASE_NAME
 from .utils import config_from_parts_or_dicts


 class TestQueryHeaders:
-    query_header = MacroQueryStringSetter(
-        config_from_parts_or_dicts(
+    def setup_method(self, _):
+        config = config_from_parts_or_dicts(
             {
                 "name": "query_headers",
                 "version": "0.1",
@@ -29,10 +29,10 @@
                 },
                 "target": "test",
             },
-        ),
-        mock.MagicMock(macros={}),
-    )
-    query_header.comment = _QueryComment(None)
+        )
+        self.query_header = AthenaMacroQueryStringSetter(
+            config, generate_query_header_context(config, mock.MagicMock(macros={}))
+        )

     def test_append_comment_with_semicolon(self):
         self.query_header.comment.query_comment = "executed by dbt"
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 6638d605..3c1f4324 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -3,11 +3,11 @@

 import botocore.session
 import pytest
+from dbt_common.exceptions import DbtRuntimeError

 from dbt.adapters.athena import AthenaCredentials
 from dbt.adapters.athena.session import AthenaSparkSessionManager, get_boto3_session
-from dbt.contracts.connection import Connection
-from dbt.exceptions import DbtRuntimeError
+from dbt.adapters.contracts.connection import Connection


 class TestSession:
diff --git a/tests/unit/utils.py b/tests/unit/utils.py
index 791181cf..320fdc13 100644
--- a/tests/unit/utils.py
+++ b/tests/unit/utils.py
@@ -124,7 +124,7 @@ def clear_plugin(plugin):

 class TestAdapterConversions:
     def _get_tester_for(self, column_type):
-        from dbt.clients import agate_helper
+        from dbt_common.clients import agate_helper

         if column_type is agate.TimeDelta:  # dbt never makes this!
             return agate.TimeDelta()