
Commit 9bbc61b: get tests working
Kyle Wigley committed Feb 17, 2021 (1 parent: e08a23e)
Showing 4 changed files with 136 additions and 28 deletions.
2 changes: 1 addition & 1 deletion dev_requirements.txt
@@ -10,6 +10,6 @@ pytest-xdist>=2.1.0,<3
flaky>=3.5.3,<4

# Test requirements
git+https://github.com/fishtown-analytics/dbt-adapter-tests.git@feature/add-integration-test-tools
git+https://github.com/fishtown-analytics/dbt-adapter-tests.git@33872d1cc0f936677dae091c3e0b49771c280514
sasl==0.2.1
thrift_sasl==0.4.1
107 changes: 98 additions & 9 deletions test/custom/base.py
@@ -1,7 +1,69 @@
from dbt_adapter_tests import DBTIntegrationTestBase, use_profile
import pytest
from functools import wraps
import os
from dbt_adapter_tests import DBTIntegrationTestBase


class DBTSparkIntegrationTest(DBTIntegrationTestBase):


def get_profile(self, adapter_type):
if adapter_type == 'apache_spark':
return self.apache_spark_profile()
elif adapter_type == 'databricks_cluster':
return self.databricks_cluster_profile()
elif adapter_type == 'databricks_sql_endpoint':
return self.databricks_sql_endpoint_profile()
else:
raise ValueError('invalid adapter type {}'.format(adapter_type))

@staticmethod
def _profile_from_test_name(test_name):
adapter_names = ('apache_spark', 'databricks_cluster',
'databricks_sql_endpoint')
adapters_in_name = sum(x in test_name for x in adapter_names)
if adapters_in_name != 1:
raise ValueError(
'test names must have exactly 1 profile choice embedded, {} has {}'
.format(test_name, adapters_in_name)
)

for adapter_name in adapter_names:
if adapter_name in test_name:
return adapter_name

raise ValueError(
'could not find adapter name in test name {}'.format(test_name)
)

def run_sql(self, query, fetch='None', kwargs=None, connection_name=None):
if connection_name is None:
connection_name = '__test'

if query.strip() == "":
return

sql = self.transform_sql(query, kwargs=kwargs)

with self.get_connection(connection_name) as conn:
cursor = conn.handle.cursor()
try:
cursor.execute(sql)
if fetch == 'one':
return cursor.fetchall()[0]
elif fetch == 'all':
return cursor.fetchall()
else:
# we have to fetch.
cursor.fetchall()
except Exception as e:
conn.handle.rollback()
conn.transaction_open = False
print(sql)
print(e)
raise
else:
conn.transaction_open = False

def apache_spark_profile(self):
return {
'config': {
@@ -14,13 +76,13 @@ def apache_spark_profile(self):
'host': 'localhost',
'user': 'dbt',
'method': 'thrift',
'port': '10000',
'connect_retries': '5',
'connect_timeout': '60',
'port': 10000,
'connect_retries': 5,
'connect_timeout': 60,
'schema': self.unique_schema()
},
},
'target': 'default2'
}
}
}

@@ -40,11 +102,11 @@ def databricks_cluster_profile(self):
'port': 443,
'schema': self.unique_schema()
},
},
'target': 'odbc'
}
}
}

def databricks_sql_endpoint_profile(self):
return {
'config': {
@@ -61,7 +123,34 @@ def databricks_sql_endpoint_profile(self):
'port': 443,
'schema': self.unique_schema()
},
},
'target': 'default2'
}
}
}


def use_profile(profile_name):
"""A decorator to declare a test method as using a particular profile.
Handles both setting the nose attr and calling self.use_profile.
Use like this:
class TestSomething(DBIntegrationTest):
@use_profile('postgres')
def test_postgres_thing(self):
self.assertEqual(self.adapter_type, 'postgres')
@use_profile('snowflake')
def test_snowflake_thing(self):
self.assertEqual(self.adapter_type, 'snowflake')
"""
def outer(wrapped):
@getattr(pytest.mark, 'profile_'+profile_name)
@wraps(wrapped)
def func(self, *args, **kwargs):
return wrapped(self, *args, **kwargs)
# sanity check at import time
assert DBTSparkIntegrationTest._profile_from_test_name(
wrapped.__name__) == profile_name
return func
return outer
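
A minimal usage sketch of this base class and decorator (not part of the commit; the class name, model directory, and test body below are hypothetical, and it assumes DBTIntegrationTestBase supplies run_dbt and reads the models property, as in the test file further down):

    from test.custom.base import DBTSparkIntegrationTest, use_profile


    class TestExampleModels(DBTSparkIntegrationTest):  # hypothetical test class
        @property
        def models(self):
            # hypothetical directory of models exercised by this test
            return "models"

        @use_profile("apache_spark")
        def test_example_apache_spark(self):
            # the profile name must appear in the method name, or the
            # import-time assert inside use_profile fails; the decorator
            # also applies the pytest marker profile_apache_spark
            self.run_dbt(["run"])
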
10 changes: 10 additions & 0 deletions test/custom/conftest.py
@@ -0,0 +1,10 @@
def pytest_configure(config):
config.addinivalue_line(
"markers", "profile_databricks_cluster"
)
config.addinivalue_line(
"markers", "profile_databricks_sql_endpoint"
)
config.addinivalue_line(
"markers", "profile_apache_spark"
)
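
Registering the three profile_* marks here keeps pytest from warning about unknown markers when use_profile applies them, and it lets a run be narrowed to a single profile, e.g. pytest -m profile_apache_spark test/custom (an illustrative invocation; the actual CI setup is not part of this commit).
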
45 changes: 27 additions & 18 deletions test/custom/incremental_strategies/test_incremental_strategies.py
@@ -1,4 +1,6 @@
from test.custom.base import DBTSparkIntegrationTest
from test.custom.base import DBTSparkIntegrationTest, use_profile
import dbt.exceptions


class TestIncrementalStrategies(DBTSparkIntegrationTest):
@property
@@ -14,73 +16,80 @@ def run_and_test(self):
self.run_dbt(["run"])
self.assertTablesEqual("default_append", "expected_append")


class TestDefaultAppend(TestIncrementalStrategies):
@use_profile("apache_spark")
def test_default_append_apache_spark(self):
self.run_and_test()

@use_profile("databricks_cluster")
def test_default_append_databricks(self):
def test_default_append_databricks_cluster(self):
self.run_and_test()


class TestInsertOverwrite(TestIncrementalStrategies):
@property
def models(self):
return "models_insert_overwrite"

def run_and_test(self):
self.run_dbt(["seed"])
self.run_dbt(["run"])
self.assertTablesEqual("insert_overwrite_no_partitions", "expected_overwrite")
self.assertTablesEqual("insert_overwrite_partitions", "expected_upsert")

self.assertTablesEqual(
"insert_overwrite_no_partitions", "expected_overwrite")
self.assertTablesEqual(
"insert_overwrite_partitions", "expected_upsert")

@use_profile("apache_spark")
def test_insert_overwrite_apache_spark(self):
self.run_and_test()

@use_profile("databricks_cluster")
def test_insert_overwrite_databricks(self):
def test_insert_overwrite_databricks_cluster(self):
self.run_and_test()


class TestDeltaStrategies(TestIncrementalStrategies):
@property
def models(self):
return "models_delta"

def run_and_test(self):
self.run_dbt(["seed"])
self.run_dbt(["run"])
self.assertTablesEqual("append_delta", "expected_append")
self.assertTablesEqual("merge_no_key", "expected_append")
self.assertTablesEqual("merge_unique_key", "expected_upsert")

@use_profile("databricks_cluster")
def test_delta_strategies_databricks(self):
def test_delta_strategies_databricks_cluster(self):
self.run_and_test()


class TestBadStrategies(TestIncrementalStrategies):
@property
def models(self):
return "models_insert_overwrite"

def run_and_test(self):
with self.assertRaises(dbt.exceptions.Exception) as exc:
self.run_dbt(["compile"])
message = str(exc.exception)
self.assertIn("Invalid file format provided", message)
self.assertIn("Invalid incremental strategy provided", message)

@use_profile("apache_spark")
def test_bad_strategies_apache_spark(self):
self.run_and_test()

@use_profile("databricks_cluster")
def test_bad_strategies_databricks(self):
def test_bad_strategies_databricks_cluster(self):
self.run_and_test()



class TestBadStrategyWithEndpoint(TestInsertOverwrite):
@use_profile("databricks_sql_endpoint")
def run_and_test(self):
def test_bad_strategies_databricks_sql_endpoint(self):
with self.assertRaises(dbt.exceptions.Exception) as exc:
self.run_dbt(["compile"], "--target", "odbc-sql-endpoint")
message = str(exc.exception)
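
The renamed test methods above (for example, test_default_append_databricks becoming test_default_append_databricks_cluster) are what let the new import-time sanity check in use_profile pass: _profile_from_test_name requires exactly one of the three profile names to be embedded in the method name. A quick illustration, assuming the base class above is importable:

    from test.custom.base import DBTSparkIntegrationTest

    DBTSparkIntegrationTest._profile_from_test_name(
        "test_default_append_databricks_cluster")  # returns 'databricks_cluster'
    DBTSparkIntegrationTest._profile_from_test_name(
        "test_default_append_databricks")  # raises ValueError: the name embeds 0 profile choices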
