Skip to content

Commit

Permalink
Enable inline csv format in unit testing (#8743)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Oct 5, 2023
1 parent 5cafb96 commit 3b6f9bd
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 25 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230928-163205.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Enable inline csv fixtures in unit tests
time: 2023-09-28T16:32:05.573776-04:00
custom:
Author: gshank
Issue: "8626"
5 changes: 5 additions & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ class ModelConfig(NodeConfig):
)


@dataclass
class UnitTestNodeConfig(NodeConfig):
expected_rows: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class SeedConfig(NodeConfig):
materialized: str = "seed"
Expand Down
9 changes: 6 additions & 3 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
UnparsedSourceTableDefinition,
UnparsedColumn,
UnitTestOverrides,
InputFixture,
UnitTestInputFixture,
UnitTestOutputFixture,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
Expand Down Expand Up @@ -78,6 +79,7 @@
SnapshotConfig,
SemanticModelConfig,
UnitTestConfig,
UnitTestNodeConfig,
)


Expand Down Expand Up @@ -1063,13 +1065,14 @@ class UnitTestNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]})
attached_node: Optional[str] = None
overrides: Optional[UnitTestOverrides] = None
config: UnitTestNodeConfig = field(default_factory=UnitTestNodeConfig)


@dataclass
class UnitTestDefinition(GraphNode):
model: str
given: Sequence[InputFixture]
expect: List[Dict[str, Any]]
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
overrides: Optional[UnitTestOverrides] = None
depends_on: DependsOn = field(default_factory=DependsOn)
Expand Down
53 changes: 49 additions & 4 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime
import re
import csv
from io import StringIO

from dbt import deprecations
from dbt.node_types import NodeType
Expand Down Expand Up @@ -736,10 +738,53 @@ def normalize_date(d: Optional[datetime.date]) -> Optional[datetime.datetime]:
return dt


class UnitTestFormat(StrEnum):
CSV = "csv"
Dict = "dict"


class UnitTestFixture:
@property
def format(self) -> UnitTestFormat:
return UnitTestFormat.Dict

@property
def rows(self) -> Union[str, List[Dict[str, Any]]]:
return []

def get_rows(self) -> List[Dict[str, Any]]:
if self.format == UnitTestFormat.Dict:
assert isinstance(self.rows, List)
return self.rows
elif self.format == UnitTestFormat.CSV:
assert isinstance(self.rows, str)
dummy_file = StringIO(self.rows)
reader = csv.DictReader(dummy_file)
rows = []
for row in reader:
rows.append(row)
return rows

def validate_fixture(self, fixture_type, test_name) -> None:
if (self.format == UnitTestFormat.Dict and not isinstance(self.rows, list)) or (
self.format == UnitTestFormat.CSV and not isinstance(self.rows, str)
):
raise ParsingError(
f"Unit test {test_name} has {fixture_type} rows which do not match format {self.format}"
)


@dataclass
class InputFixture(dbtClassMixin):
class UnitTestInputFixture(dbtClassMixin, UnitTestFixture):
input: str
rows: List[Dict[str, Any]] = field(default_factory=list)
rows: Union[str, List[Dict[str, Any]]] = ""
format: UnitTestFormat = UnitTestFormat.Dict


@dataclass
class UnitTestOutputFixture(dbtClassMixin, UnitTestFixture):
rows: Union[str, List[Dict[str, Any]]] = ""
format: UnitTestFormat = UnitTestFormat.Dict


@dataclass
Expand All @@ -752,8 +797,8 @@ class UnitTestOverrides(dbtClassMixin):
@dataclass
class UnparsedUnitTestDefinition(dbtClassMixin):
name: str
given: Sequence[InputFixture]
expect: List[Dict[str, Any]]
given: Sequence[UnitTestInputFixture]
expect: UnitTestOutputFixture
description: str = ""
overrides: Optional[UnitTestOverrides] = None
config: Dict[str, Any] = field(default_factory=dict)
Expand Down
16 changes: 11 additions & 5 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dbt.context.providers import generate_parse_exposure, get_rendered
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.model_config import NodeConfig
from dbt.contracts.graph.model_config import UnitTestNodeConfig, ModelConfig
from dbt.contracts.graph.nodes import (
ModelNode,
UnitTestNode,
Expand Down Expand Up @@ -66,7 +66,9 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
path=get_pseudo_test_path(name, test_case.original_file_path),
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
config=NodeConfig(materialized="unit", _extra={"expected_rows": test_case.expect}),
config=UnitTestNodeConfig(
materialized="unit", expected_rows=test_case.expect.get_rows()
),
raw_code=actual_node.raw_code,
database=actual_node.database,
schema=actual_node.schema,
Expand Down Expand Up @@ -118,16 +120,15 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
# TODO: package_name?
input_name = f"{test_case.model}__{test_case.name}__{original_input_node.name}"
input_unique_id = f"model.{package_name}.{input_name}"

input_node = ModelNode(
raw_code=self._build_raw_code(given.rows, original_input_node_columns),
raw_code=self._build_raw_code(given.get_rows(), original_input_node_columns),
resource_type=NodeType.Model,
package_name=package_name,
path=original_input_node.path,
original_file_path=original_input_node.original_file_path,
unique_id=input_unique_id,
name=input_name,
config=NodeConfig(materialized="ephemeral"),
config=ModelConfig(materialized="ephemeral"),
database=original_input_node.database,
schema=original_input_node.schema,
alias=original_input_node.alias,
Expand Down Expand Up @@ -189,6 +190,11 @@ def parse(self) -> ParseResult:
unit_test_fqn = [self.project.project_name] + model_name_split + [test.name]
unit_test_config = self._build_unit_test_config(unit_test_fqn, test.config)

# Check that format and type of rows matches for each given input
for input in test.given:
input.validate_fixture("input", test.name)
test.expect.validate_fixture("expected", test.name)

unit_test_definition = UnitTestDefinition(
name=test.name,
model=unit_test_suite.model,
Expand Down
7 changes: 7 additions & 0 deletions schemas/dbt/manifest/v11.json
Original file line number Diff line number Diff line change
Expand Up @@ -6011,6 +6011,13 @@
"type": "string"
}
}
},
"format": {
"enum": [
"csv",
"dict"
],
"default": "dict"
}
},
"additionalProperties": false,
Expand Down
Loading

0 comments on commit 3b6f9bd

Please sign in to comment.