Skip to content

Commit

Permalink
Move everything into pytest fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
JCZuurmond committed Dec 28, 2021
1 parent a9d717a commit fa668eb
Showing 1 changed file with 52 additions and 16 deletions.
68 changes: 52 additions & 16 deletions tests/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@
import os

import dbt.tracking
from dbt.adapters.factory import get_adapter, register_adapter
import pytest
from _pytest.fixtures import SubRequest
from dbt.adapters.factory import get_adapter, register_adapter, AdapterContainer
from dbt.clients.jinja import MacroGenerator
from dbt.config.runtime import RuntimeConfig
from dbt.context import providers
from dbt.contracts.connection import ConnectionState
from dbt.adapters.spark.connections import SparkConnectionManager, PyodbcConnectionWrapper
from dbt.contracts.graph.manifest import Manifest
from dbt.adapters.spark.connections import (
SparkConnectionManager,
PyodbcConnectionWrapper,
)
from dbt.parser.manifest import ManifestLoader
from dbt.tracking import User
from pyspark.sql import SparkSession
from sodaspark.scan import Connection


dbt.tracking.active_user = User(os.getcwd())


Expand All @@ -29,24 +37,52 @@ class Args:
project_dir: str = os.getcwd()


args = Args()
# Sets the Spark plugin in dbt.adapters.factory.FACTORY
config = RuntimeConfig.from_args(args)
@pytest.fixture
def config() -> RuntimeConfig:
config = RuntimeConfig.from_args(Args())
return config

register_adapter(config)

adapter = get_adapter(config)
@pytest.fixture
def adapter(config: RuntimeConfig) -> AdapterContainer:
register_adapter(config)
adapter = get_adapter(config)

connection_manager = _SparkConnectionManager(adapter.config)
adapter.connections = connection_manager
adapter.acquire_connection()
connection_manager = _SparkConnectionManager(adapter.config)
adapter.connections = connection_manager

manifest = ManifestLoader.get_full_manifest(config)
adapter.acquire_connection()

macro = manifest.macros["macro.spark_utils.get_tables"]
return adapter

context = providers.generate_runtime_macro_context(
macro, config, manifest, macro.package_name
)

result = MacroGenerator(macro, context)()
@pytest.fixture
def manifest(
adapter: AdapterContainer,
) -> Manifest:
manifest = ManifestLoader.get_full_manifest(adapter.config)
return manifest


@pytest.fixture
def macro_generator(
request: SubRequest, config: RuntimeConfig, manifest: Manifest
) -> MacroGenerator:
macro = manifest.macros[request.param]
context = providers.generate_runtime_macro_context(
macro, config, manifest, macro.package_name
)
macro_generator = MacroGenerator(macro, context)
return macro_generator


@pytest.mark.parametrize(
"macro_generator", ["macro.spark_utils.get_tables"], indirect=True
)
def test_create_table(
spark_session: SparkSession, macro_generator: MacroGenerator
) -> None:
expected_table = "default.example"
spark_session.sql(f"CREATE TABLE {expected_table} (id int) USING parquet")
tables = macro_generator()
assert tables == [expected_table]

0 comments on commit fa668eb

Please sign in to comment.