diff --git a/.circleci/config.yml b/.circleci/config.yml
index 7f31af4..8c24a96 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -30,10 +30,16 @@ jobs:
             python3.8 -m venv venv
             . venv/bin/activate
             pip install --upgrade pip setuptools
-            pip install --pre --upgrade dbt-spark[ODBC]
+            pip install -r dev-requirements.txt
             mkdir -p ~/.dbt
             cp integration_tests/ci/sample.profiles.yml ~/.dbt/profiles.yml
 
+      - run:
+          name: "Run Functional Tests"
+          command: |
+            . venv/bin/activate
+            python3 -m pytest tests/functional --profile databricks_sql_endpoint
+
       - run:
           name: "Run Tests - dbt-utils"
diff --git a/.gitignore b/.gitignore
index a0e4833..14e076f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,5 @@
 /**/dbt_packages/
 /**/logs/
 /**/env/
+/**/__pycache__/
+test.env
diff --git a/dbt-utils b/dbt-utils
index ac072a3..fbbc0fb 160000
--- a/dbt-utils
+++ b/dbt-utils
@@ -1 +1 @@
-Subproject commit ac072a3c4b78d43a1c013e7de8b8fa6e290b544e
+Subproject commit fbbc0fb82c9e7298cfe7fb305aa316e533977112
diff --git a/dbt_project.yml b/dbt_project.yml
index 14f6bec..cda6511 100644
--- a/dbt_project.yml
+++ b/dbt_project.yml
@@ -1,5 +1,5 @@
 name: 'spark_utils'
 version: '0.3.0'
 config-version: 2
-require-dbt-version: [">=1.0.0", "<2.0.0"]
+require-dbt-version: [">=1.2.0", "<2.0.0"]
 macro-paths: ["macros"]
\ No newline at end of file
diff --git a/dev-requirements.txt b/dev-requirements.txt
new file mode 100644
index 0000000..866f0f3
--- /dev/null
+++ b/dev-requirements.txt
@@ -0,0 +1,5 @@
+pytest
+pyodbc==4.0.32
+git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core
+git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter
+git+https://github.com/dbt-labs/dbt-spark.git#egg=dbt-spark[ODBC]
\ No newline at end of file
diff --git a/integration_tests/dbt_utils/dbt_project.yml b/integration_tests/dbt_utils/dbt_project.yml
index ecfe622..18198c3 100644
--- a/integration_tests/dbt_utils/dbt_project.yml
+++ b/integration_tests/dbt_utils/dbt_project.yml
@@ -30,6 +30,7 @@ seeds:
 models:
   dbt_utils_integration_tests:
     +file_format: delta
+
     sql:
       # macro doesn't work for this integration test (schema pattern)
      test_get_relations_by_pattern:
@@ -41,13 +42,19 @@ models:
       test_pivot_apostrophe:
         +enabled: false
     generic_tests:
-      # integration test doesn't work
+      # default version of this integration test uses an explicit cast to 'datetime'
+      # which SparkSQL does not support. override with our own version
       test_recency:
         +enabled: false
     cross_db_utils:
       # integration test doesn't work
       test_any_value:
         +enabled: false
-      # integration test doesn't work
-      test_listagg:
-        +enabled: false
\ No newline at end of file
+
+tests:
+  dbt_utils_integration_tests:
+    cross_db_utils:
+      # expect exactly two failures
+      # (both use "order by", which isn't supported in SparkSQL)
+      assert_equal_test_listagg_actual__expected:
+        error_if: ">2"
diff --git a/integration_tests/dbt_utils/models/test_recency.sql b/integration_tests/dbt_utils/models/test_recency.sql
new file mode 100644
index 0000000..d44b5bc
--- /dev/null
+++ b/integration_tests/dbt_utils/models/test_recency.sql
@@ -0,0 +1,2 @@
+select
+    {{ dbt_utils.date_trunc('day', dbt_utils.current_timestamp()) }} as today
diff --git a/macros/dbt_utils/cross_db_utils/concat.sql b/macros/dbt_utils/cross_db_utils/concat.sql
deleted file mode 100644
index 30f1a42..0000000
--- a/macros/dbt_utils/cross_db_utils/concat.sql
+++ /dev/null
@@ -1,3 +0,0 @@
-{% macro spark__concat(fields) -%}
-    concat({{ fields|join(', ') }})
-{%- endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/dateadd.sql b/macros/dbt_utils/cross_db_utils/dateadd.sql
deleted file mode 100644
index e433bc7..0000000
--- a/macros/dbt_utils/cross_db_utils/dateadd.sql
+++ /dev/null
@@ -1,62 +0,0 @@
-{% macro spark__dateadd(datepart, interval, from_date_or_timestamp) %}
-
-    {%- set clock_component -%}
-        {# make sure the dates + timestamps are real, otherwise raise an error asap #}
-        to_unix_timestamp({{ spark_utils.assert_not_null('to_timestamp', from_date_or_timestamp) }})
-        - to_unix_timestamp({{ spark_utils.assert_not_null('date', from_date_or_timestamp) }})
-    {%- endset -%}
-
-    {%- if datepart in ['day', 'week'] -%}
-
-        {%- set multiplier = 7 if datepart == 'week' else 1 -%}
-
-        to_timestamp(
-            to_unix_timestamp(
-                date_add(
-                    {{ spark_utils.assert_not_null('date', from_date_or_timestamp) }},
-                    cast({{interval}} * {{multiplier}} as int)
-                )
-            ) + {{clock_component}}
-        )
-
-    {%- elif datepart in ['month', 'quarter', 'year'] -%}
-
-        {%- set multiplier -%}
-            {%- if datepart == 'month' -%} 1
-            {%- elif datepart == 'quarter' -%} 3
-            {%- elif datepart == 'year' -%} 12
-            {%- endif -%}
-        {%- endset -%}
-
-        to_timestamp(
-            to_unix_timestamp(
-                add_months(
-                    {{ spark_utils.assert_not_null('date', from_date_or_timestamp) }},
-                    cast({{interval}} * {{multiplier}} as int)
-                )
-            ) + {{clock_component}}
-        )
-
-    {%- elif datepart in ('hour', 'minute', 'second', 'millisecond', 'microsecond') -%}
-
-        {%- set multiplier -%}
-            {%- if datepart == 'hour' -%} 3600
-            {%- elif datepart == 'minute' -%} 60
-            {%- elif datepart == 'second' -%} 1
-            {%- elif datepart == 'millisecond' -%} (1/1000000)
-            {%- elif datepart == 'microsecond' -%} (1/1000000)
-            {%- endif -%}
-        {%- endset -%}
-
-        to_timestamp(
-            {{ spark_utils.assert_not_null('to_unix_timestamp', from_date_or_timestamp) }}
-            + cast({{interval}} * {{multiplier}} as int)
-        )
-
-    {%- else -%}
-
-        {{ exceptions.raise_compiler_error("macro dateadd not implemented for datepart ~ '" ~ datepart ~ "' ~ on Spark") }}
-
-    {%- endif -%}
-
-{% endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/datediff.sql b/macros/dbt_utils/cross_db_utils/datediff.sql
deleted file mode 100644
index 0496cfa..0000000
--- a/macros/dbt_utils/cross_db_utils/datediff.sql
+++ /dev/null
@@ -1,107 +0,0 @@
-{% macro spark__datediff(first_date, second_date, datepart) %}
-
-    {%- if datepart in ['day', 'week', 'month', 'quarter', 'year'] -%}
-
-        {# make sure the dates are real, otherwise raise an error asap #}
-        {% set first_date = spark_utils.assert_not_null('date', first_date) %}
-        {% set second_date = spark_utils.assert_not_null('date', second_date) %}
-
-    {%- endif -%}
-
-    {%- if datepart == 'day' -%}
-
-        datediff({{second_date}}, {{first_date}})
-
-    {%- elif datepart == 'week' -%}
-
-        case when {{first_date}} < {{second_date}}
-            then floor(datediff({{second_date}}, {{first_date}})/7)
-            else ceil(datediff({{second_date}}, {{first_date}})/7)
-            end
-
-        -- did we cross a week boundary (Sunday)?
-        + case
-            when {{first_date}} < {{second_date}} and dayofweek({{second_date}}) < dayofweek({{first_date}}) then 1
-            when {{first_date}} > {{second_date}} and dayofweek({{second_date}}) > dayofweek({{first_date}}) then -1
-            else 0 end
-
-    {%- elif datepart == 'month' -%}
-
-        case when {{first_date}} < {{second_date}}
-            then floor(months_between(date({{second_date}}), date({{first_date}})))
-            else ceil(months_between(date({{second_date}}), date({{first_date}})))
-            end
-
-        -- did we cross a month boundary?
-        + case
-            when {{first_date}} < {{second_date}} and dayofmonth({{second_date}}) < dayofmonth({{first_date}}) then 1
-            when {{first_date}} > {{second_date}} and dayofmonth({{second_date}}) > dayofmonth({{first_date}}) then -1
-            else 0 end
-
-    {%- elif datepart == 'quarter' -%}
-
-        case when {{first_date}} < {{second_date}}
-            then floor(months_between(date({{second_date}}), date({{first_date}}))/3)
-            else ceil(months_between(date({{second_date}}), date({{first_date}}))/3)
-            end
-
-        -- did we cross a quarter boundary?
-        + case
-            when {{first_date}} < {{second_date}} and (
-                (dayofyear({{second_date}}) - (quarter({{second_date}}) * 365/4))
-                < (dayofyear({{first_date}}) - (quarter({{first_date}}) * 365/4))
-            ) then 1
-            when {{first_date}} > {{second_date}} and (
-                (dayofyear({{second_date}}) - (quarter({{second_date}}) * 365/4))
-                > (dayofyear({{first_date}}) - (quarter({{first_date}}) * 365/4))
-            ) then -1
-            else 0 end
-
-    {%- elif datepart == 'year' -%}
-
-        year({{second_date}}) - year({{first_date}})
-
-    {%- elif datepart in ('hour', 'minute', 'second', 'millisecond', 'microsecond') -%}
-
-        {%- set divisor -%}
-            {%- if datepart == 'hour' -%} 3600
-            {%- elif datepart == 'minute' -%} 60
-            {%- elif datepart == 'second' -%} 1
-            {%- elif datepart == 'millisecond' -%} (1/1000)
-            {%- elif datepart == 'microsecond' -%} (1/1000000)
-            {%- endif -%}
-        {%- endset -%}
-
-        case when {{first_date}} < {{second_date}}
-            then ceil((
-                {# make sure the timestamps are real, otherwise raise an error asap #}
-                {{ spark_utils.assert_not_null('to_unix_timestamp', spark_utils.assert_not_null('to_timestamp', second_date)) }}
-                - {{ spark_utils.assert_not_null('to_unix_timestamp', spark_utils.assert_not_null('to_timestamp', first_date)) }}
-            ) / {{divisor}})
-            else floor((
-                {{ spark_utils.assert_not_null('to_unix_timestamp', spark_utils.assert_not_null('to_timestamp', second_date)) }}
-                - {{ spark_utils.assert_not_null('to_unix_timestamp', spark_utils.assert_not_null('to_timestamp', first_date)) }}
-            ) / {{divisor}})
-            end
-
-        {% if datepart == 'millisecond' %}
-            + cast(date_format({{second_date}}, 'SSS') as int)
-            - cast(date_format({{first_date}}, 'SSS') as int)
-        {% endif %}
-
-        {% if datepart == 'microsecond' %}
-            {% set capture_str = '[0-9]{4}-[0-9]{2}-[0-9]{2}.[0-9]{2}:[0-9]{2}:[0-9]{2}.([0-9]{6})' %}
-            -- Spark doesn't really support microseconds, so this is a massive hack!
-            -- It will only work if the timestamp-string is of the format
-            -- 'yyyy-MM-dd-HH mm.ss.SSSSSS'
-            + cast(regexp_extract({{second_date}}, '{{capture_str}}', 1) as int)
-            - cast(regexp_extract({{first_date}}, '{{capture_str}}', 1) as int)
-        {% endif %}
-
-    {%- else -%}
-
-        {{ exceptions.raise_compiler_error("macro datediff not implemented for datepart ~ '" ~ datepart ~ "' ~ on Spark") }}
-
-    {%- endif -%}
-
-{% endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/deprecated/assert_not_null.sql b/macros/dbt_utils/cross_db_utils/deprecated/assert_not_null.sql
new file mode 100644
index 0000000..cbfe19b
--- /dev/null
+++ b/macros/dbt_utils/cross_db_utils/deprecated/assert_not_null.sql
@@ -0,0 +1,3 @@
+{% macro assert_not_null(function, arg) -%}
+    {{ return(adapter.dispatch('assert_not_null', 'dbt')(function, arg)) }}
+{%- endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/deprecated/concat.sql b/macros/dbt_utils/cross_db_utils/deprecated/concat.sql
new file mode 100644
index 0000000..13c316c
--- /dev/null
+++ b/macros/dbt_utils/cross_db_utils/deprecated/concat.sql
@@ -0,0 +1,3 @@
+{% macro spark__concat(fields) -%}
+    {{ return(adapter.dispatch('concat', 'dbt')(fields)) }}
+{%- endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/datatypes.sql b/macros/dbt_utils/cross_db_utils/deprecated/datatypes.sql
similarity index 66%
rename from macros/dbt_utils/cross_db_utils/datatypes.sql
rename to macros/dbt_utils/cross_db_utils/deprecated/datatypes.sql
index c935d02..b418221 100644
--- a/macros/dbt_utils/cross_db_utils/datatypes.sql
+++ b/macros/dbt_utils/cross_db_utils/deprecated/datatypes.sql
@@ -1,5 +1,5 @@
 {# numeric ------------------------------------------------  #}
 
 {% macro spark__type_numeric() %}
-    decimal(28, 6)
+    {{ return(adapter.dispatch('type_numeric', 'dbt')()) }}
 {% endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/deprecated/dateadd.sql b/macros/dbt_utils/cross_db_utils/deprecated/dateadd.sql
new file mode 100644
index 0000000..964ad98
--- /dev/null
+++ b/macros/dbt_utils/cross_db_utils/deprecated/dateadd.sql
@@ -0,0 +1,6 @@
+{% macro spark__dateadd(datepart, interval, from_date_or_timestamp) %}
+    -- dispatch here gets very very confusing
+    -- we just need to hint to dbt that this is a required macro for resolving dbt.spark__datediff()
+    -- {{ assert_not_null() }}
+    {{ return(adapter.dispatch('dateadd', 'dbt')(datepart, interval, from_date_or_timestamp)) }}
+{% endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/deprecated/datediff.sql b/macros/dbt_utils/cross_db_utils/deprecated/datediff.sql
new file mode 100644
index 0000000..46b406f
--- /dev/null
+++ b/macros/dbt_utils/cross_db_utils/deprecated/datediff.sql
@@ -0,0 +1,6 @@
+{% macro spark__datediff(first_date, second_date, datepart) %}
+    -- dispatch here gets very very confusing
+    -- we just need to hint to dbt that this is a required macro for resolving dbt.spark__datediff()
+    -- {{ assert_not_null() }}
+    {{ return(adapter.dispatch('datediff', 'dbt')(first_date, second_date, datepart)) }}
+{% endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/deprecated/split_part.sql b/macros/dbt_utils/cross_db_utils/deprecated/split_part.sql
new file mode 100644
index 0000000..114a131
--- /dev/null
+++ b/macros/dbt_utils/cross_db_utils/deprecated/split_part.sql
@@ -0,0 +1,3 @@
+{% macro spark__split_part(string_text, delimiter_text, part_number) %}
+    {{ return(adapter.dispatch('split_part', 'dbt')(string_text, delimiter_text, part_number)) }}
+{% endmacro %}
diff --git a/macros/dbt_utils/cross_db_utils/split_part.sql b/macros/dbt_utils/cross_db_utils/split_part.sql
deleted file mode 100644
index b476e05..0000000
--- a/macros/dbt_utils/cross_db_utils/split_part.sql
+++ /dev/null
@@ -1,23 +0,0 @@
-{% macro spark__split_part(string_text, delimiter_text, part_number) %}
-
-    {% set delimiter_expr %}
-
-        -- escape if starts with a special character
-        case when regexp_extract({{ delimiter_text }}, '([^A-Za-z0-9])(.*)', 1) != '_'
-            then concat('\\', {{ delimiter_text }})
-            else {{ delimiter_text }} end
-
-    {% endset %}
-
-    {% set split_part_expr %}
-
-        split(
-            {{ string_text }},
-            {{ delimiter_expr }}
-            )[({{ part_number - 1 }})]
-
-    {% endset %}
-
-    {{ return(split_part_expr) }}
-
-{% endmacro %}
diff --git a/macros/etc/assert_not_null.sql b/macros/etc/assert_not_null.sql
deleted file mode 100644
index e4692de..0000000
--- a/macros/etc/assert_not_null.sql
+++ /dev/null
@@ -1,9 +0,0 @@
-{% macro assert_not_null(function, arg) -%}
-    {{ return(adapter.dispatch('assert_not_null', 'spark_utils')(function, arg)) }}
-{%- endmacro %}
-
-{% macro default__assert_not_null(function, arg) %}
-
-    coalesce({{function}}({{arg}}), nvl2({{function}}({{arg}}), assert_true({{function}}({{arg}}) is not null), null))
-
-{% endmacro %}
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..c0ef765
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,8 @@
+[pytest]
+filterwarnings =
+    ignore:.*'soft_unicode' has been renamed to 'soft_str'*:DeprecationWarning
+    ignore:unclosed file .*:ResourceWarning
+env_files =
+    test.env
+testpaths =
+    tests/functional
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..0c62471
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,111 @@
+import pytest
+import os
+
+pytest_plugins = ["dbt.tests.fixtures.project"]
+
+
+def pytest_addoption(parser):
+    parser.addoption("--profile", action="store", default="apache_spark", type=str)
+
+
+# Using @pytest.mark.skip_profile('apache_spark') uses the 'skip_by_profile_type'
+# autouse fixture below
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers",
+        "skip_profile(profile): skip test for the given profile",
+    )
+
+
+@pytest.fixture(scope="session")
+def dbt_profile_target(request):
+    profile_type = request.config.getoption("--profile")
+    if profile_type == "databricks_cluster":
+        target = databricks_cluster_target()
+    elif profile_type == "databricks_sql_endpoint":
+        target = databricks_sql_endpoint_target()
+    elif profile_type == "apache_spark":
+        target = apache_spark_target()
+    elif profile_type == "databricks_http_cluster":
+        target = databricks_http_cluster_target()
+    elif profile_type == "spark_session":
+        target = spark_session_target()
+    else:
+        raise ValueError(f"Invalid profile type '{profile_type}'")
+    return target
+
+
+def apache_spark_target():
+    return {
+        "type": "spark",
+        "host": "localhost",
+        "user": "dbt",
+        "method": "thrift",
+        "port": 10000,
+        "connect_retries": 3,
+        "connect_timeout": 5,
+        "retry_all": True,
+    }
+
+
+def databricks_cluster_target():
+    return {
+        "type": "spark",
+        "method": "odbc",
+        "host": os.getenv("DBT_DATABRICKS_HOST_NAME"),
+        "cluster": os.getenv("DBT_DATABRICKS_CLUSTER_NAME"),
+        "token": os.getenv("DBT_DATABRICKS_TOKEN"),
+        "driver": os.getenv("ODBC_DRIVER"),
+        "port": 443,
+        "connect_retries": 3,
+        "connect_timeout": 5,
+        "retry_all": True,
+    }
+
+
+def databricks_sql_endpoint_target():
+    return {
+        "type": "spark",
+        "method": "odbc",
+        "host": os.getenv("DBT_DATABRICKS_HOST_NAME"),
+        "endpoint": os.getenv("DBT_DATABRICKS_ENDPOINT"),
+        "token": os.getenv("DBT_DATABRICKS_TOKEN"),
+        "driver": os.getenv("ODBC_DRIVER"),
+        "port": 443,
+        "connect_retries": 3,
+        "connect_timeout": 5,
+        "retry_all": True,
+    }
+
+
+def databricks_http_cluster_target():
+    return {
+        "type": "spark",
+        "host": os.getenv('DBT_DATABRICKS_HOST_NAME'),
+        "cluster": os.getenv('DBT_DATABRICKS_CLUSTER_NAME'),
+        "token": os.getenv('DBT_DATABRICKS_TOKEN'),
+        "method": "http",
+        "port": 443,
+        # more retries + longer timout to handle unavailability while cluster is restarting
+        # return failures quickly in dev, retry all failures in CI (up to 5 min)
+        "connect_retries": 5,
+        "connect_timeout": 60,
+        "retry_all": bool(os.getenv('DBT_DATABRICKS_RETRY_ALL', False)),
+    }
+
+
+def spark_session_target():
+    return {
+        "type": "spark",
+        "host": "localhost",
+        "method": "session",
+    }
+
+
+@pytest.fixture(autouse=True)
+def skip_by_profile_type(request):
+    profile_type = request.config.getoption("--profile")
+    if request.node.get_closest_marker("skip_profile"):
+        for skip_profile_type in request.node.get_closest_marker("skip_profile").args:
+            if skip_profile_type == profile_type:
+                pytest.skip("skipped on '{profile_type}' profile")
diff --git a/tests/functional/test_utils.py b/tests/functional/test_utils.py
new file mode 100644
index 0000000..9353f66
--- /dev/null
+++ b/tests/functional/test_utils.py
@@ -0,0 +1,72 @@
+import os
+import pytest
+from dbt.tests.util import run_dbt
+
+from dbt.tests.adapter.utils.base_utils import BaseUtils
+from dbt.tests.adapter.utils.test_concat import BaseConcat
+from dbt.tests.adapter.utils.test_dateadd import BaseDateAdd
+from dbt.tests.adapter.utils.test_datediff import BaseDateDiff
+from dbt.tests.adapter.utils.test_split_part import BaseSplitPart
+
+from dbt.tests.adapter.utils.data_types.base_data_type_macro import BaseDataTypeMacro
+from dbt.tests.adapter.utils.data_types.test_type_numeric import BaseTypeNumeric
+
+
+class BaseSparkUtilsBackCompat:
+    # install this repo as a package
+    @pytest.fixture(scope="class")
+    def packages(self):
+        return {
+            "packages": [
+                {"local": os.getcwd()},
+                {"git": "https://github.com/dbt-labs/dbt-utils"}
+            ]}
+
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "dispatch": [{
+                "macro_namespace": "dbt_utils",
+                "search_order": ["spark_utils", "dbt_utils"]
+            }]
+        }
+
+    # call the macros from the 'dbt_utils' namespace
+    # instead of the unspecified / global namespace
+    def macro_namespace(self):
+        return "dbt_utils"
+
+
+class BaseSparkUtilsBackCompatUtil(BaseSparkUtilsBackCompat, BaseUtils):
+    # actual test sequence needs to run 'deps' first
+    def test_build_assert_equal(self, project):
+        run_dbt(['deps'])
+        super().test_build_assert_equal(project)
+
+
+class BaseSparkUtilsBackCompatDataType(BaseSparkUtilsBackCompat, BaseDataTypeMacro):
+    # actual test sequence needs to run 'deps' first
+    def test_check_types_assert_match(self, project):
+        run_dbt(['deps'])
+        super().test_check_types_assert_match(project)
+
+
+class TestConcat(BaseSparkUtilsBackCompatUtil, BaseConcat):
+    pass
+
+
+class TestDateAdd(BaseSparkUtilsBackCompatUtil, BaseDateAdd):
+    pass
+
+
+class TestDateDiff(BaseSparkUtilsBackCompatUtil, BaseDateDiff):
+    pass
+
+
+class TestSplitPart(BaseSparkUtilsBackCompatUtil, BaseSplitPart):
+    pass
+
+
+class TestTypeNumeric(BaseSparkUtilsBackCompatDataType, BaseTypeNumeric):
+    def numeric_fixture_type(self):
+        return "decimal(28,6)"