Skip to content

Commit

Permalink
Convert docs_generate_tests to new framework (#5058)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Apr 20, 2022
1 parent 37b8b65 commit 7d0fccd
Show file tree
Hide file tree
Showing 34 changed files with 9,772 additions and 2,436 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20220413-183014.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Convert 029_docs_generate tests to new framework
time: 2022-04-13T18:30:14.706391-04:00
custom:
Author: gshank
Issue: "5035"
PR: "5058"
32 changes: 20 additions & 12 deletions core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ def dbt_profile_target():
}


@pytest.fixture(scope="class")
def profile_user(dbt_profile_target):
return dbt_profile_target["user"]


# This fixture can be overridden in a project. The data provided in this
# fixture will be merged into the default project dictionary via a python 'update'.
@pytest.fixture(scope="class")
Expand Down Expand Up @@ -253,16 +258,17 @@ def write_project_files(project_root, dir_name, file_dict):
# Write files out from file_dict. Can be nested directories...
def write_project_files_recursively(path, file_dict):
if type(file_dict) is not dict:
raise TestProcessingException(f"Error creating {path}. Did you forget the file extension?")
raise TestProcessingException(f"File dict is not a dict: '{file_dict}' for path '{path}'")
suffix_list = [".sql", ".csv", ".md", ".txt"]
for name, value in file_dict.items():
if name.endswith(".sql") or name.endswith(".csv") or name.endswith(".md"):
write_file(value, path, name)
elif name.endswith(".yml") or name.endswith(".yaml"):
if name.endswith(".yml") or name.endswith(".yaml"):
if isinstance(value, str):
data = value
else:
data = yaml.safe_dump(value)
write_file(data, path, name)
elif name.endswith(tuple(suffix_list)):
write_file(value, path, name)
else:
write_project_files_recursively(path.mkdir(name), value)

Expand Down Expand Up @@ -356,6 +362,7 @@ def __init__(
self.test_schema = test_schema
self.database = database
self.test_config = test_config
self.created_schemas = []

@property
def adapter(self):
Expand All @@ -377,20 +384,21 @@ def run_sql(self, sql, fetch=None):

# Create the unique test schema. Used in test setup, so that we're
# ready for initial sql prior to a run_dbt command.
def create_test_schema(self):
def create_test_schema(self, schema_name=None):
if schema_name is None:
schema_name = self.test_schema
with get_connection(self.adapter):
relation = self.adapter.Relation.create(
database=self.database, schema=self.test_schema
)
relation = self.adapter.Relation.create(database=self.database, schema=schema_name)
self.adapter.create_schema(relation)
self.created_schemas.append(schema_name)

# Drop the unique test schema, usually called in test cleanup
def drop_test_schema(self):
with get_connection(self.adapter):
relation = self.adapter.Relation.create(
database=self.database, schema=self.test_schema
)
self.adapter.drop_schema(relation)
for schema_name in self.created_schemas:
relation = self.adapter.Relation.create(database=self.database, schema=schema_name)
self.adapter.drop_schema(relation)
self.created_schemas = []

# This return a dictionary of table names to 'view' or 'table' values.
def get_tables_in_schema(self):
Expand Down
65 changes: 62 additions & 3 deletions core/dbt/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import yaml
import json
import warnings
from datetime import datetime
from typing import List
from contextlib import contextmanager

Expand Down Expand Up @@ -38,7 +39,12 @@
# get_relation_columns
# update_rows
# generate_update_clause

#
# Classes for comparing fields in dictionaries
# AnyFloat
# AnyInteger
# AnyString
# AnyStringWith
# =============================================================================


Expand Down Expand Up @@ -104,9 +110,9 @@ def copy_file(src_path, src, dest_path, dest) -> None:


# Used in tests when you want to remove a file from the project directory
def rm_file(src_path, src) -> None:
def rm_file(*paths) -> None:
# remove files from proj_path
os.remove(os.path.join(src_path, src))
os.remove(os.path.join(*paths))


# Used in tests to write out the string contents of a file to a
Expand Down Expand Up @@ -167,6 +173,16 @@ def check_result_nodes_by_unique_id(results, unique_ids):
assert set(unique_ids) == set(result_unique_ids)


# Check datetime is between start and end/now
def check_datetime_between(timestr, start, end=None):
datefmt = "%Y-%m-%dT%H:%M:%S.%fZ"
if end is None:
end = datetime.utcnow()
parsed = datetime.strptime(timestr, datefmt)
assert start <= parsed
assert end >= parsed


class TestProcessingException(Exception):
pass

Expand Down Expand Up @@ -419,3 +435,46 @@ def check_table_does_not_exist(adapter, name):
def check_table_does_exist(adapter, name):
columns = get_relation_columns(adapter, name)
assert len(columns) > 0


# Utility classes for enabling comparison of dictionaries


class AnyFloat:
"""Any float. Use this in assert calls"""

def __eq__(self, other):
return isinstance(other, float)


class AnyInteger:
"""Any Integer. Use this in assert calls"""

def __eq__(self, other):
return isinstance(other, int)


class AnyString:
"""Any string. Use this in assert calls"""

def __eq__(self, other):
return isinstance(other, str)


class AnyStringWith:
"""AnyStringWith("AUTO")"""

def __init__(self, contains=None):
self.contains = contains

def __eq__(self, other):
if not isinstance(other, str):
return False

if self.contains is None:
return True

return self.contains in other

def __repr__(self):
return "AnyStringWith<{!r}>".format(self.contains)
Loading

0 comments on commit 7d0fccd

Please sign in to comment.