diff --git a/Makefile b/Makefile index 36197c4..1fb21a2 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ check: check-fmt check-imports check-lint-python check-type test: rm -rf tests/tmp - python3 -m unittest tests + pytest tests .PHONY: test pre: fix check test diff --git a/requirements-test.txt b/requirements-test.txt index 1f959e9..8eba1f5 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -5,6 +5,7 @@ black>=23.11.0 isort>=5.12.0 pylint>=3.0.2 mypy>=1.7.1 +pytest>=8.3.1 molot~=1.0.0 dbt-postgres~=1.8.1 python-dotenv~=1.0.1 diff --git a/tests/__init__.py b/tests/__init__.py index c1cf417..cc7072f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,11 +1,10 @@ import logging from dbtmetabase.format import setup_logging - -from .test_exposures import * -from .test_format import * -from .test_manifest import * -from .test_metabase import * -from .test_models import * +from tests.test_exposures import * +from tests.test_format import * +from tests.test_manifest import * +from tests.test_metabase import * +from tests.test_models import * setup_logging(level=logging.DEBUG, path=None) diff --git a/tests/test_exposures.py b/tests/test_exposures.py index 322ac44..cdf8790 100644 --- a/tests/test_exposures.py +++ b/tests/test_exposures.py @@ -1,83 +1,86 @@ -import unittest from operator import itemgetter from pathlib import Path +import pytest import yaml -from ._mocks import FIXTURES_PATH, TMP_PATH, MockDbtMetabase - - -class TestExposures(unittest.TestCase): - def setUp(self): - self.c = MockDbtMetabase() - TMP_PATH.mkdir(exist_ok=True) - - def _assert_exposures(self, expected_path: Path, actual_path: Path): - with open(expected_path, encoding="utf-8") as f: - expected = yaml.safe_load(f) - with open(actual_path, encoding="utf-8") as f: - actual = yaml.safe_load(f) - - self.assertEqual( - sorted(expected["exposures"], key=itemgetter("name")), - actual["exposures"], - ) - - def test_exposures(self): - fixtures_path = FIXTURES_PATH / "exposure" / "default" - output_path = TMP_PATH / "exposure" / "default" - self.c.extract_exposures( - output_path=str(output_path), - output_grouping=None, - ) - - self._assert_exposures( - fixtures_path / "exposures.yml", - output_path / "exposures.yml", - ) - - def test_exposures_collection_grouping(self): - fixtures_path = FIXTURES_PATH / "exposure" / "collection" - output_path = TMP_PATH / "exposure" / "collection" - self.c.extract_exposures( - output_path=str(output_path), - output_grouping="collection", - ) - - for file in fixtures_path.iterdir(): - self._assert_exposures(file, output_path / file.name) - - def test_exposures_grouping_type(self): - fixtures_path = FIXTURES_PATH / "exposure" / "type" - output_path = TMP_PATH / "exposure" / "type" - self.c.extract_exposures( - output_path=str(output_path), - output_grouping="type", - ) - - for file in (fixtures_path / "card").iterdir(): - self._assert_exposures(file, output_path / "card" / file.name) - - for file in (fixtures_path / "dashboard").iterdir(): - self._assert_exposures(file, output_path / "dashboard" / file.name) - - def test_exposures_aliased_ref(self): - for model in self.c.manifest.read_models(): - if not model.name.startswith("stg_"): - model.alias = f"{model.name}_alias" - - aliases = [m.alias for m in self.c.manifest.read_models()] - self.assertIn("orders_alias", aliases) - self.assertIn("customers_alias", aliases) - - fixtures_path = FIXTURES_PATH / "exposure" / "default" - output_path = TMP_PATH / "exposure" / "aliased" - self.c.extract_exposures( - output_path=str(output_path), - output_grouping=None, - ) - - self._assert_exposures( - fixtures_path / "exposures.yml", - output_path / "exposures.yml", - ) +from tests._mocks import FIXTURES_PATH, TMP_PATH, MockDbtMetabase + +TMP_PATH.mkdir(exist_ok=True) + + +@pytest.fixture(name="core") +def fixture_core() -> MockDbtMetabase: + return MockDbtMetabase() + + +def _assert_exposures(expected_path: Path, actual_path: Path): + with open(expected_path, encoding="utf-8") as f: + expected = yaml.safe_load(f) + with open(actual_path, encoding="utf-8") as f: + actual = yaml.safe_load(f) + + assert actual["exposures"] == sorted(expected["exposures"], key=itemgetter("name")) + + +def test_exposures(core: MockDbtMetabase): + fixtures_path = FIXTURES_PATH / "exposure" / "default" + output_path = TMP_PATH / "exposure" / "default" + core.extract_exposures( + output_path=str(output_path), + output_grouping=None, + ) + + _assert_exposures( + fixtures_path / "exposures.yml", + output_path / "exposures.yml", + ) + + +def test_exposures_collection_grouping(core: MockDbtMetabase): + fixtures_path = FIXTURES_PATH / "exposure" / "collection" + output_path = TMP_PATH / "exposure" / "collection" + core.extract_exposures( + output_path=str(output_path), + output_grouping="collection", + ) + + for file in fixtures_path.iterdir(): + _assert_exposures(file, output_path / file.name) + + +def test_exposures_grouping_type(core: MockDbtMetabase): + fixtures_path = FIXTURES_PATH / "exposure" / "type" + output_path = TMP_PATH / "exposure" / "type" + core.extract_exposures( + output_path=str(output_path), + output_grouping="type", + ) + + for file in (fixtures_path / "card").iterdir(): + _assert_exposures(file, output_path / "card" / file.name) + + for file in (fixtures_path / "dashboard").iterdir(): + _assert_exposures(file, output_path / "dashboard" / file.name) + + +def test_exposures_aliased_ref(core: MockDbtMetabase): + for model in core.manifest.read_models(): + if not model.name.startswith("stg_"): + model.alias = f"{model.name}_alias" + + aliases = [m.alias for m in core.manifest.read_models()] + assert "orders_alias" in aliases + assert "customers_alias" in aliases + + fixtures_path = FIXTURES_PATH / "exposure" / "default" + output_path = TMP_PATH / "exposure" / "aliased" + core.extract_exposures( + output_path=str(output_path), + output_grouping=None, + ) + + _assert_exposures( + fixtures_path / "exposures.yml", + output_path / "exposures.yml", + ) diff --git a/tests/test_format.py b/tests/test_format.py index 1cf11bb..367d79b 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -1,91 +1,73 @@ -import unittest - from dbtmetabase.format import Filter, NullValue, dump_yaml, safe_description, safe_name +from tests._mocks import FIXTURES_PATH, TMP_PATH -from ._mocks import FIXTURES_PATH, TMP_PATH +def test_filter(): + assert Filter(include=("alpHa", "bRavo")).match("Alpha") + assert Filter().match("Alpha") + assert Filter().match("") + assert not Filter(include=("alpHa", "bRavo"), exclude=("alpha",)).match("Alpha") + assert not Filter(exclude=("alpha",)).match("Alpha") + assert Filter(include="alpha").match("Alpha") + assert not Filter(exclude="alpha").match("Alpha") -class TestFormat(unittest.TestCase): - def test_filter(self): - self.assertTrue( - Filter( - include=("alpHa", "bRavo"), - ).match("Alpha") - ) - self.assertTrue(Filter().match("Alpha")) - self.assertTrue(Filter().match("")) - self.assertFalse( - Filter( - include=("alpHa", "bRavo"), - exclude=("alpha",), - ).match("Alpha") - ) - self.assertFalse( - Filter( - exclude=("alpha",), - ).match("Alpha") - ) - self.assertTrue(Filter(include="alpha").match("Alpha")) - self.assertFalse(Filter(exclude="alpha").match("Alpha")) - def test_filter_wildcard(self): - self.assertTrue(Filter(include="stg_*").match("stg_orders")) - self.assertTrue(Filter(include="STG_*").match("stg_ORDERS")) - self.assertFalse(Filter(include="stg_*").match("orders")) - self.assertTrue(Filter(include="order?").match("orders")) - self.assertFalse(Filter(include="order?").match("ordersz")) - self.assertTrue(Filter(include="*orders", exclude="stg_*").match("_orders")) - self.assertFalse(Filter(include="*orders", exclude="stg_*").match("stg_orders")) +def test_filter_wildcard(): + assert Filter(include="stg_*").match("stg_orders") + assert Filter(include="STG_*").match("stg_ORDERS") + assert not Filter(include="stg_*").match("orders") + assert Filter(include="order?").match("orders") + assert not Filter(include="order?").match("ordersz") + assert Filter(include="*orders", exclude="stg_*").match("_orders") + assert not Filter(include="*orders", exclude="stg_*").match("stg_orders") - def test_null_value(self): - self.assertIsNotNone(NullValue) - self.assertFalse(NullValue) - self.assertIs(NullValue, NullValue) - def test_safe_name(self): - self.assertEqual( - "somebody_s_2_collections_", - safe_name("Somebody's 2 collections!"), - ) - self.assertEqual( - "somebody_s_2_collections_", - safe_name("somebody_s_2_collections_"), - ) - self.assertEqual("", safe_name("")) +def test_null_value(): + assert NullValue is not None + assert not NullValue + assert NullValue is NullValue + + +def test_safe_name(): + assert safe_name("Somebody's 2 collections!") == "somebody_s_2_collections_" + assert safe_name("somebody_s_2_collections_") == "somebody_s_2_collections_" + assert safe_name("") == "" - def test_safe_description(self): - self.assertEqual( - "Depends on\n\nQuestion ( #2 )!", - safe_description("Depends on\n\nQuestion {{ #2 }}!"), - ) - self.assertEqual( - "Depends on\n\nQuestion ( #2 )!", - safe_description("Depends on\n\nQuestion ( #2 )!"), - ) - self.assertEqual( - "Depends on\n\nQuestion { #2 }!", - safe_description("Depends on\n\nQuestion { #2 }!"), - ) - self.assertEqual( - "(start_date) - cast((rolling_days))", - safe_description("{{start_date}} - cast({{rolling_days}})"), - ) - def test_dump_yaml(self): - fixture_path = FIXTURES_PATH / "test_dump_yaml.yml" - output_path = TMP_PATH / "test_dump_yaml.yml" - with open(output_path, "w", encoding="utf-8") as f: - dump_yaml( - data={ - "root": { - "attr1": "val1\nend", - "attr2": ["val2", "val3"], - }, +def test_safe_description(): + assert ( + safe_description("Depends on\n\nQuestion {{ #2 }}!") + == "Depends on\n\nQuestion ( #2 )!" + ) + assert ( + safe_description("Depends on\n\nQuestion ( #2 )!") + == "Depends on\n\nQuestion ( #2 )!" + ) + assert ( + safe_description("Depends on\n\nQuestion { #2 }!") + == "Depends on\n\nQuestion { #2 }!" + ) + assert ( + safe_description("{{start_date}} - cast({{rolling_days}})") + == "(start_date) - cast((rolling_days))" + ) + + +def test_dump_yaml(): + fixture_path = FIXTURES_PATH / "test_dump_yaml.yml" + output_path = TMP_PATH / "test_dump_yaml.yml" + with open(output_path, "w", encoding="utf-8") as f: + dump_yaml( + data={ + "root": { + "attr1": "val1\nend", + "attr2": ["val2", "val3"], }, - stream=f, - ) - with open(output_path, "r", encoding="utf-8") as f: - actual = f.read() - with open(fixture_path, "r", encoding="utf-8") as f: - expected = f.read() - self.assertEqual(expected, actual) + }, + stream=f, + ) + with open(output_path, "r", encoding="utf-8") as f: + actual = f.read() + with open(fixture_path, "r", encoding="utf-8") as f: + expected = f.read() + assert actual == expected diff --git a/tests/test_manifest.py b/tests/test_manifest.py index a731d71..4a8e098 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -1,390 +1,379 @@ -import unittest from operator import attrgetter from typing import Sequence from dbtmetabase.manifest import Column, Group, Manifest, Model +from tests._mocks import FIXTURES_PATH, MockManifest -from ._mocks import FIXTURES_PATH, MockManifest +def test_v11_disabled(): + manifest = MockManifest(FIXTURES_PATH / "manifest-v11-disabled.json") + manifest.read_models() -class TestManifest(unittest.TestCase): - def test_v11_disabled(self): - manifest = MockManifest(FIXTURES_PATH / "manifest-v11-disabled.json") - manifest.read_models() + orders_mod = manifest.find_model("orders") + assert orders_mod is None - orders_mod = manifest.find_model("orders") - self.assertIsNone(orders_mod) + customer_id_col = manifest.find_column("customers", "customer_id") + assert customer_id_col is not None + assert customer_id_col.fk_target_table is None + assert customer_id_col.fk_target_field is None - customer_id_col = manifest.find_column("customers", "customer_id") - self.assertIsNotNone(customer_id_col) - self.assertIsNone(customer_id_col.fk_target_table) - self.assertIsNone(customer_id_col.fk_target_field) - def test_v12(self): - models = Manifest(FIXTURES_PATH / "manifest-v12.json").read_models() - self._assertModelsEqual( - models, - [ - Model( - database="dbtmetabase", - schema="public", - group=Group.nodes, - name="customers", - alias="customers", - description="This table has basic information about a customer, as well as some derived facts based on a customer's orders", - display_name="clients", - unique_id="model.sandbox.customers", - columns=[ - Column( - name="customer_id", - description="This is a unique identifier for a customer", - ), - Column( - name="first_name", - description="Customer's first name. PII.", - ), - Column( - name="last_name", - description="Customer's last name. PII.", - ), - Column( - name="first_order", - description="Date (UTC) of a customer's first order", - ), - Column( - name="most_recent_order", - description="Date (UTC) of a customer's most recent order", - ), - Column( - name="number_of_orders", - description="Count of the number of orders a customer has placed", - display_name="order_count", - ), - Column( - name="customer_lifetime_value", - description="Total value (AUD) of a customer's orders", - ), - ], - ), - Model( - database="dbtmetabase", - schema="public", - group=Group.nodes, - name="orders", - alias="orders", - description="This table has basic information about orders, as well as some derived facts based on payments", - points_of_interest="Basic information only", - caveats="Some facts are derived from payments", - unique_id="model.sandbox.orders", - columns=[ - Column( - name="order_id", - description="This is a unique identifier for an order", - semantic_type="type/PK", - ), - Column( - name="customer_id", - description="Foreign key to the customers table", - semantic_type="type/FK", - fk_target_table="public.customers", - fk_target_field="customer_id", - ), - Column( - name="order_date", - description="Date (UTC) that the order was placed", - ), - Column( - name="status", - description="", - ), - Column( - name="amount", - description="Total amount (AUD) of the order", - ), - Column( - name="credit_card_amount", - description="Amount of the order (AUD) paid for by credit card", - ), - Column( - name="coupon_amount", - description="Amount of the order (AUD) paid for by coupon", - ), - Column( - name="bank_transfer_amount", - description="Amount of the order (AUD) paid for by bank transfer", - ), - Column( - name="gift_card_amount", - description="Amount of the order (AUD) paid for by gift card", - ), - ], - ), - Model( - database="dbtmetabase", - schema="public", - group=Group.nodes, - name="stg_customers", - alias="stg_customers", - description="", - unique_id="model.sandbox.stg_customers", - columns=[ - Column( - name="customer_id", - description="", - ), - Column( - name="first_name", - description="", - ), - Column( - name="last_name", - description="", - ), - ], - ), - Model( - database="dbtmetabase", - schema="public", - group=Group.nodes, - name="stg_payments", - alias="stg_payments", - description="", - unique_id="model.sandbox.stg_payments", - columns=[ - Column( - name="payment_id", - description="", - ), - Column( - name="payment_method", - description="", - ), - Column( - name="order_id", - description="", - ), - Column( - name="amount", - description="", - ), - ], - ), - Model( - database="dbtmetabase", - schema="public", - group=Group.nodes, - name="stg_orders", - alias="stg_orders", - description="", - unique_id="model.sandbox.stg_orders", - columns=[ - Column( - name="order_id", - description="", - ), - Column( - name="status", - description="", - ), - Column( - name="order_date", - description="", - ), - Column( - name="customer_id", - description="", - ), - ], - ), - ], - ) +def test_v12(): + models = Manifest(FIXTURES_PATH / "manifest-v12.json").read_models() + _assert_models_equal( + models, + [ + Model( + database="dbtmetabase", + schema="public", + group=Group.nodes, + name="customers", + alias="customers", + description="This table has basic information about a customer, as well as some derived facts based on a customer's orders", + display_name="clients", + unique_id="model.sandbox.customers", + columns=[ + Column( + name="customer_id", + description="This is a unique identifier for a customer", + ), + Column( + name="first_name", + description="Customer's first name. PII.", + ), + Column( + name="last_name", + description="Customer's last name. PII.", + ), + Column( + name="first_order", + description="Date (UTC) of a customer's first order", + ), + Column( + name="most_recent_order", + description="Date (UTC) of a customer's most recent order", + ), + Column( + name="number_of_orders", + description="Count of the number of orders a customer has placed", + display_name="order_count", + ), + Column( + name="customer_lifetime_value", + description="Total value (AUD) of a customer's orders", + ), + ], + ), + Model( + database="dbtmetabase", + schema="public", + group=Group.nodes, + name="orders", + alias="orders", + description="This table has basic information about orders, as well as some derived facts based on payments", + points_of_interest="Basic information only", + caveats="Some facts are derived from payments", + unique_id="model.sandbox.orders", + columns=[ + Column( + name="order_id", + description="This is a unique identifier for an order", + semantic_type="type/PK", + ), + Column( + name="customer_id", + description="Foreign key to the customers table", + semantic_type="type/FK", + fk_target_table="public.customers", + fk_target_field="customer_id", + ), + Column( + name="order_date", + description="Date (UTC) that the order was placed", + ), + Column( + name="status", + description="", + ), + Column( + name="amount", + description="Total amount (AUD) of the order", + ), + Column( + name="credit_card_amount", + description="Amount of the order (AUD) paid for by credit card", + ), + Column( + name="coupon_amount", + description="Amount of the order (AUD) paid for by coupon", + ), + Column( + name="bank_transfer_amount", + description="Amount of the order (AUD) paid for by bank transfer", + ), + Column( + name="gift_card_amount", + description="Amount of the order (AUD) paid for by gift card", + ), + ], + ), + Model( + database="dbtmetabase", + schema="public", + group=Group.nodes, + name="stg_customers", + alias="stg_customers", + description="", + unique_id="model.sandbox.stg_customers", + columns=[ + Column( + name="customer_id", + description="", + ), + Column( + name="first_name", + description="", + ), + Column( + name="last_name", + description="", + ), + ], + ), + Model( + database="dbtmetabase", + schema="public", + group=Group.nodes, + name="stg_payments", + alias="stg_payments", + description="", + unique_id="model.sandbox.stg_payments", + columns=[ + Column( + name="payment_id", + description="", + ), + Column( + name="payment_method", + description="", + ), + Column( + name="order_id", + description="", + ), + Column( + name="amount", + description="", + ), + ], + ), + Model( + database="dbtmetabase", + schema="public", + group=Group.nodes, + name="stg_orders", + alias="stg_orders", + description="", + unique_id="model.sandbox.stg_orders", + columns=[ + Column( + name="order_id", + description="", + ), + Column( + name="status", + description="", + ), + Column( + name="order_date", + description="", + ), + Column( + name="customer_id", + description="", + ), + ], + ), + ], + ) - def test_v2(self): - models = Manifest(FIXTURES_PATH / "manifest-v2.json").read_models() - self._assertModelsEqual( - models, - [ - Model( - database="test", - schema="public", - group=Group.nodes, - name="orders", - alias="orders", - description="This table has basic information about orders, as well as some derived facts based on payments", - unique_id="model.jaffle_shop.orders", - columns=[ - Column( - name="order_id", - description="This is a unique identifier for an order", - ), - Column( - name="customer_id", - description="Foreign key to the customers table", - semantic_type="type/FK", - fk_target_table="public.customers", - fk_target_field="customer_id", - ), - Column( - name="order_date", - description="Date (UTC) that the order was placed", - ), - Column( - name="status", - description="Orders can be one of the following statuses:\n\n| status | description |\n|----------------|------------------------------------------------------------------------------------------------------------------------|\n| placed | The order has been placed but has not yet left the warehouse |\n| shipped | The order has ben shipped to the customer and is currently in transit |\n| completed | The order has been received by the customer |\n| return_pending | The customer has indicated that they would like to return the order, but it has not yet been received at the warehouse |\n| returned | The order has been returned by the customer and received at the warehouse |", - ), - Column( - name="amount", - description="Total amount (AUD) of the order", - ), - Column( - name="credit_card_amount", - description="Amount of the order (AUD) paid for by credit card", - ), - Column( - name="coupon_amount", - description="Amount of the order (AUD) paid for by coupon", - ), - Column( - name="bank_transfer_amount", - description="Amount of the order (AUD) paid for by bank transfer", - ), - Column( - name="gift_card_amount", - description="Amount of the order (AUD) paid for by gift card", - ), - ], - ), - Model( - database="test", - schema="public", - group=Group.nodes, - name="customers", - alias="customers", - description="This table has basic information about a customer, as well as some derived facts based on a customer's orders", - unique_id="model.jaffle_shop.customers", - columns=[ - Column( - name="customer_id", - description="This is a unique identifier for a customer", - semantic_type=None, # This is a PK field, should not be detected as FK - ), - Column( - name="first_name", - description="Customer's first name. PII.", - ), - Column( - name="last_name", - description="Customer's last name. PII.", - ), - Column( - name="first_order", - description="Date (UTC) of a customer's first order", - ), - Column( - name="most_recent_order", - description="Date (UTC) of a customer's most recent order", - ), - Column( - name="number_of_orders", - description="Count of the number of orders a customer has placed", - ), - Column( - name="customer_lifetime_value", - description="Total value (AUD) of a customer's orders", - ), - ], - ), - Model( - database="test", - schema="public", - group=Group.nodes, - name="stg_orders", - alias="stg_orders", - description="", - unique_id="model.jaffle_shop.stg_orders", - columns=[ - Column( - name="order_id", - description="", - ), - Column( - name="status", - description="", - ), - ], - ), - Model( - database="test", - schema="public", - group=Group.nodes, - name="stg_payments", - alias="stg_payments", - description="", - unique_id="model.jaffle_shop.stg_payments", - columns=[ - Column( - name="payment_id", - description="", - ), - Column( - name="payment_method", - description="", - ), - ], - ), - Model( - database="test", - schema="public", - group=Group.nodes, - name="stg_customers", - alias="stg_customers", - description="", - unique_id="model.jaffle_shop.stg_customers", - tags=[], - columns=[ - Column( - name="customer_id", - description="", - ) - ], - ), - ], - ) - def _assertModelsEqual( - self, - first: Sequence[Model], - second: Sequence[Model], - ): - self.assertEqual(len(first), len(second), "mismatched model count") +def test_v2(): + models = Manifest(FIXTURES_PATH / "manifest-v2.json").read_models() + _assert_models_equal( + models, + [ + Model( + database="test", + schema="public", + group=Group.nodes, + name="orders", + alias="orders", + description="This table has basic information about orders, as well as some derived facts based on payments", + unique_id="model.jaffle_shop.orders", + columns=[ + Column( + name="order_id", + description="This is a unique identifier for an order", + ), + Column( + name="customer_id", + description="Foreign key to the customers table", + semantic_type="type/FK", + fk_target_table="public.customers", + fk_target_field="customer_id", + ), + Column( + name="order_date", + description="Date (UTC) that the order was placed", + ), + Column( + name="status", + description="Orders can be one of the following statuses:\n\n| status | description |\n|----------------|------------------------------------------------------------------------------------------------------------------------|\n| placed | The order has been placed but has not yet left the warehouse |\n| shipped | The order has ben shipped to the customer and is currently in transit |\n| completed | The order has been received by the customer |\n| return_pending | The customer has indicated that they would like to return the order, but it has not yet been received at the warehouse |\n| returned | The order has been returned by the customer and received at the warehouse |", + ), + Column( + name="amount", + description="Total amount (AUD) of the order", + ), + Column( + name="credit_card_amount", + description="Amount of the order (AUD) paid for by credit card", + ), + Column( + name="coupon_amount", + description="Amount of the order (AUD) paid for by coupon", + ), + Column( + name="bank_transfer_amount", + description="Amount of the order (AUD) paid for by bank transfer", + ), + Column( + name="gift_card_amount", + description="Amount of the order (AUD) paid for by gift card", + ), + ], + ), + Model( + database="test", + schema="public", + group=Group.nodes, + name="customers", + alias="customers", + description="This table has basic information about a customer, as well as some derived facts based on a customer's orders", + unique_id="model.jaffle_shop.customers", + columns=[ + Column( + name="customer_id", + description="This is a unique identifier for a customer", + semantic_type=None, # This is a PK field, should not be detected as FK + ), + Column( + name="first_name", + description="Customer's first name. PII.", + ), + Column( + name="last_name", + description="Customer's last name. PII.", + ), + Column( + name="first_order", + description="Date (UTC) of a customer's first order", + ), + Column( + name="most_recent_order", + description="Date (UTC) of a customer's most recent order", + ), + Column( + name="number_of_orders", + description="Count of the number of orders a customer has placed", + ), + Column( + name="customer_lifetime_value", + description="Total value (AUD) of a customer's orders", + ), + ], + ), + Model( + database="test", + schema="public", + group=Group.nodes, + name="stg_orders", + alias="stg_orders", + description="", + unique_id="model.jaffle_shop.stg_orders", + columns=[ + Column( + name="order_id", + description="", + ), + Column( + name="status", + description="", + ), + ], + ), + Model( + database="test", + schema="public", + group=Group.nodes, + name="stg_payments", + alias="stg_payments", + description="", + unique_id="model.jaffle_shop.stg_payments", + columns=[ + Column( + name="payment_id", + description="", + ), + Column( + name="payment_method", + description="", + ), + ], + ), + Model( + database="test", + schema="public", + group=Group.nodes, + name="stg_customers", + alias="stg_customers", + description="", + unique_id="model.jaffle_shop.stg_customers", + tags=[], + columns=[ + Column( + name="customer_id", + description="", + ) + ], + ), + ], + ) - first = sorted(first, key=attrgetter("name")) - second = sorted(second, key=attrgetter("name")) - for i, first_model in enumerate(first): - second_model = second[i] - self.assertEqual(first_model.name, second_model.name, "wrong model") - self.assertEqual( - len(first_model.columns), - len(second_model.columns), - f"mismatched column count in {first_model.name}", - ) - for j, first_column in enumerate(first_model.columns): - second_column = second_model.columns[j] - self.assertEqual( - first_column.name, - second_column.name, - f"wrong column in model {first_model.name}", - ) - self.assertEqual( - first_column, - second_column, - f"mismatched column {first_model.name}.{first_column.name}", - ) - self.assertEqual( - first_model, - second_model, - f"mismatched model {first_model.name}", - ) +def _assert_models_equal( + first: Sequence[Model], + second: Sequence[Model], +): + assert len(first) == len(second), "mismatched model count" - self.assertEqual(first, second) + first = sorted(first, key=attrgetter("name")) + second = sorted(second, key=attrgetter("name")) + + for i, first_model in enumerate(first): + second_model = second[i] + assert first_model.name == second_model.name, "wrong model" + assert len(first_model.columns) == len( + second_model.columns + ), f"mismatched column count in {first_model.name}" + for j, first_column in enumerate(first_model.columns): + second_column = second_model.columns[j] + assert ( + first_column.name == second_column.name + ), f"wrong column in model {first_model.name}" + assert ( + first_column == second_column + ), f"mismatched column {first_model.name}.{first_column.name}" + assert first_model == second_model, f"mismatched model {first_model.name}" + + assert first == second diff --git a/tests/test_metabase.py b/tests/test_metabase.py index 65541b9..4b7c841 100644 --- a/tests/test_metabase.py +++ b/tests/test_metabase.py @@ -1,40 +1,43 @@ -import unittest - -from ._mocks import MockMetabase - - -class TestMetabase(unittest.TestCase): - def setUp(self): - self.metabase = MockMetabase(url="http://localhost") - - def test_metabase_find_database(self): - db = self.metabase.find_database(name="dbtmetabase") - assert db - self.assertEqual(2, db["id"]) - self.assertIsNone(self.metabase.find_database(name="foo")) - - def test_metabase_get_collections(self): - excluded = self.metabase.get_collections(exclude_personal=True) - self.assertEqual(1, len(excluded)) - - included = self.metabase.get_collections(exclude_personal=False) - self.assertEqual(2, len(included)) - - def test_metabase_get_collection_items(self): - cards = self.metabase.get_collection_items( - uid="root", - models=("card",), - ) - self.assertEqual({"card"}, {item["model"] for item in cards}) - - dashboards = self.metabase.get_collection_items( - uid="root", - models=("dashboard",), - ) - self.assertEqual({"dashboard"}, {item["model"] for item in dashboards}) - - both = self.metabase.get_collection_items( - uid="root", - models=("card", "dashboard"), - ) - self.assertEqual({"card", "dashboard"}, {item["model"] for item in both}) +import pytest + +from tests._mocks import MockMetabase + + +@pytest.fixture(name="metabase") +def fixture_metabase() -> MockMetabase: + return MockMetabase(url="http://localhost") + + +def test_metabase_find_database(metabase: MockMetabase): + db = metabase.find_database(name="dbtmetabase") + assert db + assert db["id"] == 2 + assert metabase.find_database(name="foo") is None + + +def test_metabase_get_collections(metabase: MockMetabase): + excluded = metabase.get_collections(exclude_personal=True) + assert len(excluded) == 1 + + included = metabase.get_collections(exclude_personal=False) + assert len(included) == 2 + + +def test_metabase_get_collection_items(metabase: MockMetabase): + cards = metabase.get_collection_items( + uid="root", + models=("card",), + ) + assert {item["model"] for item in cards} == {"card"} + + dashboards = metabase.get_collection_items( + uid="root", + models=("dashboard",), + ) + assert {item["model"] for item in dashboards} == {"dashboard"} + + both = metabase.get_collection_items( + uid="root", + models=("card", "dashboard"), + ) + assert {item["model"] for item in both} == {"card", "dashboard"} diff --git a/tests/test_models.py b/tests/test_models.py index e488a07..4d4db60 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,78 +1,87 @@ -import unittest +from typing import MutableSequence, cast -from ._mocks import MockDbtMetabase +import pytest +from dbtmetabase.manifest import Column +from tests._mocks import MockDbtMetabase -class TestModels(unittest.TestCase): - def setUp(self): - # pylint: disable=protected-access - self.c = MockDbtMetabase() - self.c._ModelsMixin__SYNC_PERIOD = 1 # type: ignore - def test_export(self): - self.c.export_models( - metabase_database="dbtmetabase", - skip_sources=True, - sync_timeout=1, - order_fields=True, - ) +@pytest.fixture(name="core") +def fixture_core() -> MockDbtMetabase: + # pylint: disable=protected-access + c = MockDbtMetabase() + c._ModelsMixin__SYNC_PERIOD = 1 # type: ignore + return c - def test_export_hidden_table(self): - # pylint: disable=protected-access - self.c._manifest.read_models() - model = self.c._manifest.find_model("stg_customers") - model.visibility_type = "hidden" - column = model.columns[0] - column.name = "new_column_since_stale" - model.columns.append(column) +def test_export(core: MockDbtMetabase): + core.export_models( + metabase_database="dbtmetabase", + skip_sources=True, + sync_timeout=1, + order_fields=True, + ) - self.c.export_models( - metabase_database="dbtmetabase", - skip_sources=True, - sync_timeout=1, - order_fields=True, - ) - def test_build_lookups(self): - # pylint: disable=protected-access,no-member - expected = { - "PUBLIC.CUSTOMERS": [ - "CUSTOMER_ID", - "FIRST_NAME", - "LAST_NAME", - "FIRST_ORDER", - "MOST_RECENT_ORDER", - "NUMBER_OF_ORDERS", - "CUSTOMER_LIFETIME_VALUE", - ], - "PUBLIC.ORDERS": [ - "ORDER_ID", - "CUSTOMER_ID", - "ORDER_DATE", - "STATUS", - "AMOUNT", - "CREDIT_CARD_AMOUNT", - "COUPON_AMOUNT", - "BANK_TRANSFER_AMOUNT", - "GIFT_CARD_AMOUNT", - ], - "PUBLIC.RAW_CUSTOMERS": ["ID", "FIRST_NAME", "LAST_NAME"], - "PUBLIC.RAW_ORDERS": ["ID", "USER_ID", "ORDER_DATE", "STATUS"], - "PUBLIC.RAW_PAYMENTS": ["ID", "ORDER_ID", "PAYMENT_METHOD", "AMOUNT"], - "PUBLIC.STG_CUSTOMERS": ["CUSTOMER_ID", "FIRST_NAME", "LAST_NAME"], - "PUBLIC.STG_ORDERS": ["ORDER_ID", "STATUS", "ORDER_DATE", "CUSTOMER_ID"], - "PUBLIC.STG_PAYMENTS": [ - "PAYMENT_ID", - "PAYMENT_METHOD", - "ORDER_ID", - "AMOUNT", - ], - } +def test_export_hidden_table(core: MockDbtMetabase): + # pylint: disable=protected-access + core._manifest.read_models() + model = core._manifest.find_model("stg_customers") + assert model is not None + model.visibility_type = "hidden" - actual_tables = self.c._ModelsMixin__get_tables(database_id="2") # type: ignore + column = model.columns[0] + column.name = "new_column_since_stale" + columns = cast(MutableSequence[Column], model.columns) + columns.append(column) - self.assertEqual(list(expected.keys()), list(actual_tables.keys())) + core.export_models( + metabase_database="dbtmetabase", + skip_sources=True, + sync_timeout=1, + order_fields=True, + ) - for table, columns in expected.items(): - self.assertEqual(columns, list(actual_tables[table]["fields"].keys())) + +def test_build_lookups(core: MockDbtMetabase): + # pylint: disable=protected-access,no-member + expected = { + "PUBLIC.CUSTOMERS": [ + "CUSTOMER_ID", + "FIRST_NAME", + "LAST_NAME", + "FIRST_ORDER", + "MOST_RECENT_ORDER", + "NUMBER_OF_ORDERS", + "CUSTOMER_LIFETIME_VALUE", + ], + "PUBLIC.ORDERS": [ + "ORDER_ID", + "CUSTOMER_ID", + "ORDER_DATE", + "STATUS", + "AMOUNT", + "CREDIT_CARD_AMOUNT", + "COUPON_AMOUNT", + "BANK_TRANSFER_AMOUNT", + "GIFT_CARD_AMOUNT", + ], + "PUBLIC.RAW_CUSTOMERS": ["ID", "FIRST_NAME", "LAST_NAME"], + "PUBLIC.RAW_ORDERS": ["ID", "USER_ID", "ORDER_DATE", "STATUS"], + "PUBLIC.RAW_PAYMENTS": ["ID", "ORDER_ID", "PAYMENT_METHOD", "AMOUNT"], + "PUBLIC.STG_CUSTOMERS": ["CUSTOMER_ID", "FIRST_NAME", "LAST_NAME"], + "PUBLIC.STG_ORDERS": ["ORDER_ID", "STATUS", "ORDER_DATE", "CUSTOMER_ID"], + "PUBLIC.STG_PAYMENTS": [ + "PAYMENT_ID", + "PAYMENT_METHOD", + "ORDER_ID", + "AMOUNT", + ], + } + + actual_tables = core._ModelsMixin__get_tables(database_id="2") # type: ignore + + assert list(actual_tables.keys()) == list(expected.keys()) + + for table, columns in expected.items(): + assert list(actual_tables[table]["fields"].keys()) == columns