diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 78602c74953..63c6f6e961d 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.14.0 +current_version = 0.15.0a1 parse = (?P\d+) \.(?P\d+) \.(?P\d+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1a5ce78037a..3dd08d13da6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -19,7 +19,7 @@ jobs: PGUSER: root PGPASSWORD: password PGDATABASE: postgres - - run: tox -e flake8,unit-py36 + - run: tox -e flake8,mypy,unit-py36 integration-postgres-py36: docker: *test_and_postgres steps: diff --git a/.gitignore b/.gitignore index 9b97c91563b..d1025d4ef4c 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ var/ *.egg-info/ .installed.cfg *.egg +*.mypy_cache/ logs/ # PyInstaller @@ -79,3 +80,6 @@ target/ # Vim *.sw* + +# pycharm +.idea/ diff --git a/core/dbt/__init__.py b/core/dbt/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/core/dbt/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/core/dbt/adapters/__init__.py b/core/dbt/adapters/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/core/dbt/adapters/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index fecbf55bf36..d0d5ae354bb 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -22,6 +22,8 @@ class Credentials( Replaceable, metaclass=abc.ABCMeta ): + database: str + schema: str _ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False) @abc.abstractproperty diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index cafc7b51675..d1621bc0dc8 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -8,6 +8,7 @@ import dbt.flags import dbt.clients.agate_helper +from dbt.contracts.graph.manifest import Manifest from dbt.node_types import NodeType from dbt.loader import GraphLoader from dbt.logger import GLOBAL_LOGGER as logger @@ -252,8 +253,7 @@ def type(cls): @property def _internal_manifest(self): if self._internal_manifest_lazy is None: - manifest = GraphLoader.load_internal(self.config) - self._internal_manifest_lazy = manifest + self.load_internal_manifest() return self._internal_manifest_lazy def check_internal_manifest(self): @@ -262,6 +262,12 @@ def check_internal_manifest(self): """ return self._internal_manifest_lazy + def load_internal_manifest(self) -> Manifest: + if self._internal_manifest_lazy is None: + manifest = GraphLoader.load_internal(self.config) + self._internal_manifest_lazy = manifest + return self._internal_manifest_lazy + ### # Caching methods ### diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py index e4fb6b0e932..fbc654bc56f 100644 --- a/core/dbt/clients/jinja.py +++ b/core/dbt/clients/jinja.py @@ -2,6 +2,7 @@ import linecache import os import tempfile +from typing import List, Union, Set, Optional import jinja2 import jinja2._compat @@ -13,7 +14,7 @@ import dbt.exceptions import dbt.utils -from dbt.clients._jinja_blocks import BlockIterator +from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag from dbt.logger import GLOBAL_LOGGER as logger # noqa @@ -305,21 +306,24 @@ def undefined_error(msg): raise jinja2.exceptions.UndefinedError(msg) -def extract_toplevel_blocks(data, 
allowed_blocks=None, collect_raw_data=True): +def extract_toplevel_blocks( + data: str, + allowed_blocks: Optional[Set[str]] = None, + collect_raw_data: bool = True, +) -> List[Union[BlockData, BlockTag]]: """Extract the top level blocks with matching block types from a jinja file, with some special handling for block nesting. - :param str data: The data to extract blocks from. - :param Optional[Set[str]] allowed_blocks: The names of the blocks to - extract from the file. They may not be nested within if/for blocks. - If None, use the default values. - :param bool collect_raw_data: If set, raw data between matched blocks will - also be part of the results, as `BlockData` objects. They have a + :param data: The data to extract blocks from. + :param allowed_blocks: The names of the blocks to extract from the file. + They may not be nested within if/for blocks. If None, use the default + values. + :param collect_raw_data: If set, raw data between matched blocks will also + be part of the results, as `BlockData` objects. They have a `block_type_name` field of `'__dbt_data'` and will never have a `block_name`. - :return List[Union[BlockData, BlockTag]]: A list of `BlockTag`s matching - the allowed block types and (if `collect_raw_data` is `True`) - `BlockData` objects. + :return: A list of `BlockTag`s matching the allowed block types and (if + `collect_raw_data` is `True`) `BlockData` objects. """ return BlockIterator(data).lex_for_blocks( allowed_blocks=allowed_blocks, diff --git a/core/dbt/clients/system.py b/core/dbt/clients/system.py index 12f8790d3a5..9023ed85299 100644 --- a/core/dbt/clients/system.py +++ b/core/dbt/clients/system.py @@ -24,13 +24,13 @@ def find_matching(root_path, absolute root path (`relative_paths_to_search`), and a `file_pattern` like '*.sql', returns information about the files. 
For example: - > find_matching('/root/path', 'models', '*.sql') + > find_matching('/root/path', ['models'], '*.sql') [ { 'absolute_path': '/root/path/models/model_one.sql', - 'relative_path': 'models/model_one.sql', + 'relative_path': 'model_one.sql', 'searched_path': 'models' }, { 'absolute_path': '/root/path/models/subdirectory/model_two.sql', - 'relative_path': 'models/subdirectory/model_two.sql', + 'relative_path': 'subdirectory/model_two.sql', 'searched_path': 'models' } ] """ matching = [] diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 6f0aeae198f..8de231f19ab 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -94,7 +94,7 @@ def recursively_prepend_ctes(model, manifest): model.prepend_ctes(prepended_ctes) - manifest.nodes[model.unique_id] = model + manifest.update_node(model) return (model, prepended_ctes, manifest) @@ -167,7 +167,8 @@ def compile_node(self, node, manifest, extra_context=None): def write_graph_file(self, linker, manifest): filename = graph_file_name graph_path = os.path.join(self.config.target_path, filename) - linker.write_graph(graph_path, manifest) + if dbt.flags.WRITE_JSON: + linker.write_graph(graph_path, manifest) def link_node(self, linker, node, manifest): linker.add_node(node.unique_id) diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 0c911c2aaec..9757ec7aaaf 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -18,7 +18,7 @@ from dbt.ui import printer from dbt.utils import deep_map from dbt.utils import parse_cli_vars -from dbt.parser.source_config import SourceConfig +from dbt.source_config import SourceConfig from dbt.contracts.project import Project as ProjectContract from dbt.contracts.project import PackageConfig diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index 065a99a1c58..6ad78c89136 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -3,7 +3,7 @@ import dbt.clients.jinja import dbt.context.common import dbt.flags -from dbt.parser import ParserUtils +from dbt.parser.util import ParserUtils from dbt.logger import GLOBAL_LOGGER as logger # noqa diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index a6a4eb19e5d..2dc9a700ab9 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -11,6 +11,7 @@ ParsedSourceDefinition, ParsedTestNode, TestConfig, + PARSED_TYPES, ) from dbt.node_types import ( NodeType, @@ -96,6 +97,11 @@ class CompiledRPCNode(CompiledNode): class CompiledSeedNode(CompiledNode): resource_type: SeedType + @property + def empty(self): + """ Seeds are never empty""" + return False + @dataclass class CompiledSnapshotNode(CompiledNode): @@ -187,6 +193,17 @@ def compiled_type_for(parsed: ParsedNode): return type(parsed) +def parsed_instance_for(compiled: CompiledNode) -> ParsedNode: + cls = PARSED_TYPES.get(compiled.resource_type) + if cls is None: + # how??? + raise ValueError('invalid resource_type: {}' + .format(compiled.resource_type)) + + # validate=False to allow extra keys from copmiling + return cls.from_dict(compiled.to_dict(), validate=False) + + # We allow either parsed or compiled nodes, or parsed sources, as some # 'compile()' calls in the runner actually just return the original parsed # node they were given. 
diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 60344be5a12..652865ed006 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -1,23 +1,152 @@ +import hashlib +import os +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List, Optional, Union, Mapping +from uuid import UUID + +from hologram import JsonSchemaMixin + from dbt.contracts.graph.parsed import ParsedNode, ParsedMacro, \ ParsedDocumentation from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.util import Writable, Replaceable from dbt.config import Project -from dbt.exceptions import raise_duplicate_resource_name -from dbt.node_types import NodeType +from dbt.exceptions import raise_duplicate_resource_name, InternalException from dbt.logger import GLOBAL_LOGGER as logger +from dbt.node_types import NodeType from dbt import tracking import dbt.utils -from hologram import JsonSchemaMixin +NodeEdgeMap = Dict[str, List[str]] -from dataclasses import dataclass, field -from datetime import datetime -from typing import Dict, List, Optional -from uuid import UUID +@dataclass +class FilePath(JsonSchemaMixin): + searched_path: str + relative_path: str + absolute_path: str + + @property + def search_key(self): + # TODO: should this be project root + original_file_path? + return self.absolute_path -NodeEdgeMap = Dict[str, List[str]] + @property + def original_file_path(self): + return os.path.join(self.searched_path, self.relative_path) + + +@dataclass +class FileHash(JsonSchemaMixin): + name: str # the hash type name + checksum: str # the hashlib.hash_type().hexdigest() of the file contents + + @classmethod + def empty(cls): + return FileHash(name='none', checksum='') + + @classmethod + def path(cls, path: str): + return FileHash(name='path', checksum=path) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + + if self.name == 'none' or self.name != other.name: + return False + + return self.checksum == other.checksum + + def compare(self, contents: str) -> bool: + """Compare the file contents with the given hash""" + if self.name == 'none': + return False + + return self.from_contents(contents, name=self.name) == self.checksum + + @classmethod + def from_contents(cls, contents: str, name='sha256'): + """Create a file hash from the given file contents. The hash is always + the utf-8 encoding of the contents given, because dbt only reads files + as utf-8. + """ + data = contents.encode('utf-8') + checksum = hashlib.new(name, data).hexdigest() + return cls(name=name, checksum=checksum) + + +@dataclass +class RemoteFile(JsonSchemaMixin): + @property + def searched_path(self) -> str: + return 'from remote system' + + @property + def relative_path(self) -> str: + return 'from remote system' + + @property + def absolute_path(self) -> str: + return 'from remote system' + + @property + def original_file_path(self): + return 'from remote system' + + +@dataclass +class SourceFile(JsonSchemaMixin): + """Define a source file in dbt""" + path: Union[FilePath, RemoteFile] # the path information + checksum: FileHash + # we don't want to serialize this + _contents: Optional[str] = None + # the unique IDs contained in this file + nodes: List[str] = field(default_factory=list) + docs: List[str] = field(default_factory=list) + macros: List[str] = field(default_factory=list) + sources: List[str] = field(default_factory=list) + # any node patches in this file. 
The entries are names, not unique ids! + patches: List[str] = field(default_factory=list) + + @property + def search_key(self) -> Optional[str]: + if isinstance(self.path, RemoteFile): + return None + if self.checksum.name == 'none': + return None + return self.path.search_key + + @property + def contents(self) -> str: + if self._contents is None: + raise InternalException('SourceFile has no contents!') + return self._contents + + @contents.setter + def contents(self, value): + self._contents = value + + @classmethod + def empty(cls, path: FilePath) -> 'SourceFile': + self = cls(path=path, checksum=FileHash.empty()) + self.contents = '' + return self + + @classmethod + def seed(cls, path: FilePath) -> 'SourceFile': + """Seeds always parse the same regardless of their content.""" + self = cls(path=path, checksum=FileHash.path(path.absolute_path)) + self.contents = '' + return self + + @classmethod + def remote(cls, contents: str) -> 'SourceFile': + self = cls(path=RemoteFile(), checksum=FileHash.empty()) + self.contents = contents + return self @dataclass @@ -57,21 +186,23 @@ def _deepcopy(value): class Manifest: """The manifest for the full graph, after parsing and during compilation. """ - nodes: Dict[str, CompileResultNode] - macros: Dict[str, ParsedMacro] - docs: Dict[str, ParsedDocumentation] + nodes: Mapping[str, CompileResultNode] + macros: Mapping[str, ParsedMacro] + docs: Mapping[str, ParsedDocumentation] generated_at: datetime disabled: List[ParsedNode] + files: Mapping[str, SourceFile] metadata: ManifestMetadata = field(init=False) def __init__( self, - nodes: Dict[str, CompileResultNode], - macros: Dict[str, ParsedMacro], - docs: Dict[str, ParsedDocumentation], + nodes: Mapping[str, CompileResultNode], + macros: Mapping[str, ParsedMacro], + docs: Mapping[str, ParsedDocumentation], generated_at: datetime, disabled: List[ParsedNode], - config: Optional[Project] = None + files: Mapping[str, SourceFile], + config: Optional[Project] = None, ) -> None: self.metadata = self.get_metadata(config) self.nodes = nodes @@ -79,9 +210,40 @@ def __init__( self.docs = docs self.generated_at = generated_at self.disabled = disabled + self.files = files self._flat_graph = None super(Manifest, self).__init__() + @classmethod + def from_macros(cls, macros=None, files=None) -> 'Manifest': + if macros is None: + macros = {} + if files is None: + files = {} + return cls( + nodes={}, + macros=macros, + docs={}, + generated_at=datetime.utcnow(), + disabled=[], + files=files, + config=None, + ) + + def update_node(self, new_node): + unique_id = new_node.unique_id + if unique_id not in self.nodes: + raise dbt.exceptions.RuntimeException( + 'got an update_node call with an unrecognized node: {}' + .format(unique_id) + ) + existing = self.nodes[unique_id] + if new_node.original_file_path != existing.original_file_path: + raise dbt.exceptions.RuntimeException( + 'cannot update a node to have a new file path!' + ) + self.nodes[unique_id] = new_node + @staticmethod def get_metadata(config: Optional[Project]) -> ManifestMetadata: project_id = None @@ -101,23 +263,6 @@ def get_metadata(config: Optional[Project]) -> ManifestMetadata: send_anonymous_usage_stats=send_anonymous_usage_stats, ) - def serialize(self): - """Convert the parsed manifest to a nested dict structure that we can - safely serialize to JSON. 
- """ - forward_edges, backward_edges = build_edges(self.nodes.values()) - - return { - 'nodes': {k: v.to_dict() for k, v in self.nodes.items()}, - 'macros': {k: v.to_dict() for k, v in self.macros.items()}, - 'docs': {k: v.to_dict() for k, v in self.docs.items()}, - 'parent_map': backward_edges, - 'child_map': forward_edges, - 'generated_at': self.generated_at, - 'metadata': self.metadata, - 'disabled': [v.to_dict() for v in self.disabled], - } - def to_flat_graph(self): """This function gets called in context.common by each node, so we want to cache it. Make sure you don't call this until you're done with @@ -138,7 +283,6 @@ def find_disabled_by_name(self, name, package=None): def _find_by_name(self, name, package, subgraph, nodetype): """ - Find a node by its given name in the appropriate sugraph. If package is None, all pacakges will be searched. nodetype should be a list of NodeTypes to accept. @@ -265,6 +409,8 @@ def add_nodes(self, new_nodes): def patch_nodes(self, patches): """Patch nodes with the given dict of patches. Note that this consumes the input! + This relies on the fact that all nodes have unique _name_ fields, not + just unique unique_id fields. """ # because we don't have any mapping from node _names_ to nodes, and we # only have the node name in the patch, we have to iterate over all the @@ -310,11 +456,13 @@ def deepcopy(self, config=None): docs={k: _deepcopy(v) for k, v in self.docs.items()}, generated_at=self.generated_at, disabled=[_deepcopy(n) for n in self.disabled], - config=config + config=config, + files={k: _deepcopy(v) for k, v in self.files.items()}, ) def writable_manifest(self): forward_edges, backward_edges = build_edges(self.nodes.values()) + return WritableManifest( nodes=self.nodes, macros=self.macros, @@ -323,7 +471,8 @@ def writable_manifest(self): metadata=self.metadata, disabled=self.disabled, child_map=forward_edges, - parent_map=backward_edges + parent_map=backward_edges, + files=self.files, ) @classmethod @@ -335,6 +484,7 @@ def from_writable_manifest(cls, writable): generated_at=writable.generated_at, metadata=writable.metadata, disabled=writable.disabled, + files=writable.files, ) self.metadata = writable.metadata return self @@ -355,11 +505,13 @@ def write(self, path): @dataclass class WritableManifest(JsonSchemaMixin, Writable): - nodes: Dict[str, CompileResultNode] - macros: Dict[str, ParsedMacro] - docs: Dict[str, ParsedDocumentation] + nodes: Mapping[str, CompileResultNode] + macros: Mapping[str, ParsedMacro] + docs: Mapping[str, ParsedDocumentation] disabled: Optional[List[ParsedNode]] generated_at: datetime parent_map: Optional[NodeEdgeMap] child_map: Optional[NodeEdgeMap] metadata: ManifestMetadata + # map of original_file_path to all unique IDs provided by that file + files: Mapping[str, SourceFile] diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index f69b7ca47f4..50a7cb123b8 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -5,6 +5,7 @@ from hologram.helpers import StrEnum, NewPatternType, ExtensibleJsonSchemaMixin import dbt.clients.jinja +import dbt.flags from dbt.contracts.graph.unparsed import ( UnparsedNode, UnparsedMacro, UnparsedDocumentationFile, Quoting, UnparsedBaseNode, FreshnessThreshold @@ -61,10 +62,6 @@ class NodeConfig(ExtensibleJsonSchemaMixin, Replaceable): tags: Union[List[str], str] = field(default_factory=list) _extra: Dict[str, Any] = field(default_factory=dict) - def __post_init__(self): - if isinstance(self.tags, str): - self.tags 
= [self.tags] - @property def extra(self): return self._extra @@ -157,8 +154,8 @@ def patch(self, patch): self.description = patch.description self.columns = patch.columns self.docrefs = patch.docrefs - # patches should always trigger re-validation - self.to_dict(validate=True) + if dbt.flags.STRICT_MODE: + self.to_dict(validate=True) def get_materialization(self): return self.config.materialized @@ -173,6 +170,7 @@ class ParsedNodeMandatory( HasUniqueID, HasFqn, HasRelationMetadata, + Replaceable ): alias: str @@ -181,7 +179,7 @@ class ParsedNodeMandatory( class ParsedNodeDefaults(ParsedNodeMandatory): config: NodeConfig = field(default_factory=NodeConfig) tags: List[str] = field(default_factory=list) - refs: List[List[Any]] = field(default_factory=list) + refs: List[List[str]] = field(default_factory=list) sources: List[List[Any]] = field(default_factory=list) depends_on: DependsOn = field(default_factory=DependsOn) docrefs: List[Docref] = field(default_factory=list) @@ -221,6 +219,11 @@ class ParsedRPCNode(ParsedNode): class ParsedSeedNode(ParsedNode): resource_type: SeedType + @property + def empty(self): + """ Seeds are never empty""" + return False + @dataclass class TestConfig(NodeConfig): @@ -238,13 +241,13 @@ class ParsedTestNode(ParsedNode): class _SnapshotConfig(NodeConfig): unique_key: str target_schema: str - target_database: str + target_database: Optional[str] = None def __init__( self, unique_key: str, - target_database: str, target_schema: str, + target_database: Optional[str] = None, **kwargs ) -> None: self.target_database = target_database @@ -340,8 +343,8 @@ class ParsedSnapshotNode(ParsedNode): ] @classmethod - def json_schema(cls): - schema = super().json_schema() + def json_schema(cls, embeddable=False): + schema = super().json_schema(embeddable) # mess with config configs = [ @@ -349,7 +352,11 @@ def json_schema(cls): (str(TimestampStrategy.Timestamp), TimestampSnapshotConfig), ] - schema['properties']['config'] = _create_if_else_chain( + if embeddable: + dest = schema[cls.__name__]['properties'] + else: + dest = schema['properties'] + dest['config'] = _create_if_else_chain( 'strategy', configs, GenericSnapshotConfig ) return schema @@ -373,12 +380,13 @@ class MacroDependsOn(JsonSchemaMixin, Replaceable): @dataclass -class ParsedMacro(UnparsedMacro): +class ParsedMacro(UnparsedMacro, HasUniqueID): name: str resource_type: MacroType - unique_id: str - tags: List[str] - depends_on: MacroDependsOn + # TODO: can macros even have tags? + tags: List[str] = field(default_factory=list) + # TODO: is this ever populated? 
+ depends_on: MacroDependsOn = field(default_factory=MacroDependsOn) def local_vars(self): return {} @@ -392,9 +400,8 @@ def generator(self): @dataclass -class ParsedDocumentation(UnparsedDocumentationFile): +class ParsedDocumentation(UnparsedDocumentationFile, HasUniqueID): name: str - unique_id: str block_contents: str diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index 3f44c569b56..96cb249fdef 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -88,7 +88,6 @@ class SourceFreshnessResult(JsonSchemaMixin, Writable): age: Real status: FreshnessStatus error: Optional[str] = None - status: Union[None, str, int, bool] = None execution_time: Union[str, int] = 0 thread_id: Optional[int] = 0 timing: List[TimingInfo] = field(default_factory=list) diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index f68ff084255..d8521292166 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -1,7 +1,9 @@ import builtins import functools +from typing import NoReturn from dbt.logger import GLOBAL_LOGGER as logger +from dbt.node_types import NodeType import dbt.flags import hologram @@ -298,15 +300,15 @@ def __str__(self): return '{} running: {}'.format(self.msg, self.cmd) -def raise_compiler_error(msg, node=None): +def raise_compiler_error(msg, node=None) -> NoReturn: raise CompilationException(msg, node) -def raise_database_error(msg, node=None): +def raise_database_error(msg, node=None) -> NoReturn: raise DatabaseException(msg, node) -def raise_dependency_error(msg): +def raise_dependency_error(msg) -> NoReturn: raise DependencyException(msg) @@ -344,9 +346,15 @@ def ref_bad_context(model, args): # better error messages. Ex. If models foo_users and bar_users are aliased # to 'users', in their respective schemas, then you would want to see # 'bar_users' in your error messge instead of just 'users'. + if isinstance(model, dict): # TODO: remove this path + model_name = model['name'] + model_path = model['path'] + else: + model_name = model.name + model_path = model.path error_msg = base_error_msg.format( - model_name=model['name'], - model_path=model['path'], + model_name=model_name, + model_path=model_path, ref_string=ref_string ) raise_compiler_error(error_msg, model) @@ -578,13 +586,20 @@ def approximate_relation_match(target, relation): def raise_duplicate_resource_name(node_1, node_2): duped_name = node_1.name + if node_1.resource_type in NodeType.refable(): + get_func = 'ref("{}")'.format(duped_name) + elif node_1.resource_type == NodeType.Source: + get_func = 'source("{}", "{}")'.format(node_1.source_name, duped_name) + elif node_1.resource_type == NodeType.Test and 'schema' in node_1.tags: + return + raise_compiler_error( 'dbt found two resources with the name "{}". Since these resources ' 'have the same name,\ndbt will be unable to find the correct resource ' - 'when ref("{}") is used. To fix this,\nchange the name of one of ' + 'when {} is used. To fix this,\nchange the name of one of ' 'these resources:\n- {} ({})\n- {} ({})'.format( duped_name, - duped_name, + get_func, node_1.unique_id, node_1.original_file_path, node_2.unique_id, node_2.original_file_path)) @@ -635,12 +650,14 @@ def raise_patch_targets_not_found(patches): def raise_duplicate_patch_name(name, patch_1, patch_2): raise_compiler_error( - 'dbt found two schema.yml entries for the same model named {}. The ' - 'first patch was specified in {} and the second in {}. 
Models and ' - 'their associated columns may only be described a single time.'.format( + 'dbt found two schema.yml entries for the same model named {0}. ' + 'Models and their associated columns may only be described a single ' + 'time. To fix this, remove the model entry for for {0} in one of ' + 'these files:\n - {1}\n - {2}' + .format( name, - patch_1, - patch_2, + patch_1.original_file_path, + patch_2.original_file_path, ) ) diff --git a/core/dbt/flags.py b/core/dbt/flags.py index 9b27ca4207c..9ddfb6880c2 100644 --- a/core/dbt/flags.py +++ b/core/dbt/flags.py @@ -1,29 +1,43 @@ -STRICT_MODE = False -FULL_REFRESH = False -USE_CACHE = True -WARN_ERROR = False -TEST_NEW_PARSER = False +# initially all flags are set to None, the on-load call of reset() will set +# them for their first time. +STRICT_MODE = None +FULL_REFRESH = None +USE_CACHE = None +WARN_ERROR = None +TEST_NEW_PARSER = None +WRITE_JSON = None +PARTIAL_PARSE = None def reset(): - global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER + global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \ + WRITE_JSON, PARTIAL_PARSE STRICT_MODE = False FULL_REFRESH = False USE_CACHE = True WARN_ERROR = False TEST_NEW_PARSER = False + WRITE_JSON = True + PARTIAL_PARSE = False def set_from_args(args): - global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER - USE_CACHE = getattr(args, 'use_cache', True) + global STRICT_MODE, FULL_REFRESH, USE_CACHE, WARN_ERROR, TEST_NEW_PARSER, \ + WRITE_JSON, PARTIAL_PARSE + USE_CACHE = getattr(args, 'use_cache', USE_CACHE) - FULL_REFRESH = getattr(args, 'full_refresh', False) - STRICT_MODE = getattr(args, 'strict', False) + FULL_REFRESH = getattr(args, 'full_refresh', FULL_REFRESH) + STRICT_MODE = getattr(args, 'strict', STRICT_MODE) WARN_ERROR = ( STRICT_MODE or - getattr(args, 'warn_error', False) + getattr(args, 'warn_error', STRICT_MODE or WARN_ERROR) ) - TEST_NEW_PARSER = getattr(args, 'test_new_parser', False) + TEST_NEW_PARSER = getattr(args, 'test_new_parser', TEST_NEW_PARSER) + WRITE_JSON = getattr(args, 'write_json', WRITE_JSON) + PARTIAL_PARSE = getattr(args, 'partial_parse', PARTIAL_PARSE) + + +# initialize everything to the defaults on module load +reset() diff --git a/core/dbt/helper_types.py b/core/dbt/helper_types.py new file mode 100644 index 00000000000..cc15e3d18bd --- /dev/null +++ b/core/dbt/helper_types.py @@ -0,0 +1,15 @@ +# never name this package "types", or mypy will crash in ugly ways +from hologram import FieldEncoder, JsonSchemaMixin +from typing import NewType + + +Port = NewType('Port', int) + + +class PortEncoder(FieldEncoder): + @property + def json_schema(self): + return {'type': 'integer', 'minimum': 0, 'maximum': 65535} + + +JsonSchemaMixin.register_field_encoders({Port: PortEncoder()}) diff --git a/core/dbt/include/__init__.py b/core/dbt/include/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/core/dbt/include/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/core/dbt/include/global_project/macros/materializations/snapshot/snapshot.sql b/core/dbt/include/global_project/macros/materializations/snapshot/snapshot.sql index a52e302d814..51bbfd4ecee 100644 --- a/core/dbt/include/global_project/macros/materializations/snapshot/snapshot.sql +++ b/core/dbt/include/global_project/macros/materializations/snapshot/snapshot.sql @@ -180,20 +180,18 @@ {% materialization snapshot, default %} {%- set config = model['config'] -%} - {%- set 
target_database = config.get('target_database') -%} - {%- set target_schema = config.get('target_schema') -%} {%- set target_table = model.get('alias', model.get('name')) -%} {%- set strategy_name = config.get('strategy') -%} {%- set unique_key = config.get('unique_key') %} - {% if not adapter.check_schema_exists(target_database, target_schema) %} - {% do create_schema(target_database, target_schema) %} + {% if not adapter.check_schema_exists(model.database, model.schema) %} + {% do create_schema(model.database, model.schema) %} {% endif %} {% set target_relation_exists, target_relation = get_or_create_relation( - database=target_database, - schema=target_schema, + database=model.database, + schema=model.schema, identifier=target_table, type='table') -%} diff --git a/core/dbt/loader.py b/core/dbt/loader.py index 48488c2d02d..c686617bf48 100644 --- a/core/dbt/loader.py +++ b/core/dbt/loader.py @@ -1,177 +1,302 @@ -import os import itertools +import os +import pickle +from datetime import datetime +from typing import Dict, Optional, Mapping from dbt.include.global_project import PACKAGES import dbt.exceptions import dbt.flags +from dbt.logger import GLOBAL_LOGGER as logger from dbt.node_types import NodeType -from dbt.contracts.graph.manifest import Manifest - -from dbt.parser import MacroParser, ModelParser, SeedParser, AnalysisParser, \ - DocumentationParser, DataTestParser, HookParser, SchemaParser, \ - ParserUtils, SnapshotParser - -from datetime import datetime +from dbt.clients.system import make_directory +from dbt.config import Project, RuntimeConfig +from dbt.contracts.graph.compiled import CompileResultNode +from dbt.contracts.graph.manifest import Manifest, FilePath, FileHash +from dbt.parser.base import BaseParser +from dbt.parser import AnalysisParser +from dbt.parser import DataTestParser +from dbt.parser import DocumentationParser +from dbt.parser import HookParser +from dbt.parser import MacroParser +from dbt.parser import ModelParser +from dbt.parser import ParseResult +from dbt.parser import SchemaParser +from dbt.parser import SeedParser +from dbt.parser import SnapshotParser +from dbt.parser import ParserUtils +from dbt.parser.search import FileBlock +from dbt.version import __version__ + + +PARTIAL_PARSE_FILE_NAME = 'partial_parse.pickle' + + +_parser_types = [ + ModelParser, + SnapshotParser, + AnalysisParser, + DataTestParser, + HookParser, + SeedParser, + DocumentationParser, + SchemaParser, +] + + +# TODO: this should be calculated per-file based on the vars() calls made in +# parsing, so changing one var doesn't invalidate everything. also there should +# be something like that for env_var - currently changing env_vars in way that +# impact graph selection or configs will result in weird test failures. +# finally, we should hash the actual profile used, not just root project + +# profiles.yml + relevant args. While sufficient, it is definitely overkill. 
+def make_parse_result( + config: RuntimeConfig, all_projects: Mapping[str, Project] +) -> ParseResult: + """Make a ParseResult from the project configuration and the profile.""" + # if any of these change, we need to reject the parser + vars_hash = FileHash.from_contents( + '\0'.join([ + getattr(config.args, 'vars', '{}') or '{}', + getattr(config.args, 'profile', '') or '', + getattr(config.args, 'target', '') or '', + __version__ + ]) + ) + profile_path = os.path.join(config.args.profiles_dir, 'profiles.yml') + with open(profile_path) as fp: + profile_hash = FileHash.from_contents(fp.read()) + + project_hashes = {} + for name, project in all_projects.items(): + path = os.path.join(project.project_root, 'dbt_project.yml') + with open(path) as fp: + project_hashes[name] = FileHash.from_contents(fp.read()) + + return ParseResult( + vars_hash=vars_hash, + profile_hash=profile_hash, + project_hashes=project_hashes, + ) class GraphLoader: - def __init__(self, root_project, all_projects): + def __init__( + self, root_project: RuntimeConfig, all_projects: Mapping[str, Project] + ) -> None: self.root_project = root_project self.all_projects = all_projects - self.nodes = {} - self.docs = {} - self.macros = {} - self.tests = {} - self.patches = {} - self.disabled = [] - self.macro_manifest = None - - def _load_sql_nodes(self, parser_type, resource_type, relative_dirs_attr, - **kwargs): - parser = parser_type(self.root_project, self.all_projects, - self.macro_manifest) - - for project_name, project in self.all_projects.items(): - parse_results = parser.load_and_parse( - package_name=project_name, - root_dir=project.project_root, - relative_dirs=getattr(project, relative_dirs_attr), - resource_type=resource_type, - **kwargs - ) - self.nodes.update(parse_results.parsed) - self.disabled.extend(parse_results.disabled) - def _load_macros(self, internal_manifest=None): - # skip any projects in the internal manifest - all_projects = self.all_projects.copy() - if internal_manifest is not None: - for name in internal_project_names(): - all_projects.pop(name, None) - self.macros.update(internal_manifest.macros) - - # give the macroparser all projects but then only load what we haven't - # loaded already - parser = MacroParser(self.root_project, self.all_projects) - for project_name, project in all_projects.items(): - self.macros.update(parser.load_and_parse( - package_name=project_name, - root_dir=project.project_root, - relative_dirs=project.macro_paths, - resource_type=NodeType.Macro, - )) - - def _load_seeds(self): - parser = SeedParser(self.root_project, self.all_projects, - self.macro_manifest) - for project_name, project in self.all_projects.items(): - self.nodes.update(parser.load_and_parse( - package_name=project_name, - root_dir=project.project_root, - relative_dirs=project.data_paths, - )) - - def _load_nodes(self): - self._load_sql_nodes(ModelParser, NodeType.Model, 'source_paths') - self._load_sql_nodes(SnapshotParser, NodeType.Snapshot, - 'snapshot_paths') - self._load_sql_nodes(AnalysisParser, NodeType.Analysis, - 'analysis_paths') - self._load_sql_nodes(DataTestParser, NodeType.Test, 'test_paths', - tags=['data']) - - hook_parser = HookParser(self.root_project, self.all_projects, - self.macro_manifest) - self.nodes.update(hook_parser.load_and_parse()) - - self._load_seeds() - - def _load_docs(self): - parser = DocumentationParser(self.root_project, self.all_projects) - for project_name, project in self.all_projects.items(): - self.docs.update(parser.load_and_parse( - package_name=project_name, 
- root_dir=project.project_root, - relative_dirs=project.docs_paths - )) - - def _load_schema_tests(self): - parser = SchemaParser(self.root_project, self.all_projects, - self.macro_manifest) - for project_name, project in self.all_projects.items(): - tests, patches, sources = parser.load_and_parse( - package_name=project_name, - root_dir=project.project_root, - relative_dirs=project.source_paths - ) + self.results = make_parse_result(root_project, all_projects) + self._loaded_file_cache: Dict[str, FileBlock] = {} - for unique_id, test in tests.items(): - if unique_id in self.tests: - dbt.exceptions.raise_duplicate_resource_name( - test, self.tests[unique_id], - ) - self.tests[unique_id] = test + def _load_macros( + self, + old_results: Optional[ParseResult], + internal_manifest: Optional[Manifest] = None, + ) -> None: + projects = self.all_projects + if internal_manifest is not None: + projects = { + k: v for k, v in self.all_projects.items() if k not in PACKAGES + } + self.results.macros.update(internal_manifest.macros) + self.results.files.update(internal_manifest.files) + + # TODO: go back to skipping the internal manifest during macro parsing + for project in projects.values(): + parser = MacroParser(self.results, project) + for path in parser.search(): + self.parse_with_cache(path, parser, old_results) + + def parse_with_cache( + self, + path: FilePath, + parser: BaseParser, + old_results: Optional[ParseResult], + ) -> None: + block = self._get_file(path, parser) + if not self._get_cached(block, old_results): + parser.parse_file(block) + + def _get_cached( + self, + block: FileBlock, + old_results: Optional[ParseResult], + ) -> bool: + # TODO: handle multiple parsers w/ same files, by + # tracking parser type vs node type? Or tracking actual + # parser type during parsing? + if old_results is None: + return False + if old_results.has_file(block.file): + return self.results.sanitized_update(block.file, old_results) + return False + + def _get_file(self, path: FilePath, parser: BaseParser) -> FileBlock: + if path.search_key in self._loaded_file_cache: + block = self._loaded_file_cache[path.search_key] + else: + block = FileBlock(file=parser.load_file(path)) + self._loaded_file_cache[path.search_key] = block + return block + + def parse_project( + self, + project: Project, + macro_manifest: Manifest, + old_results: Optional[ParseResult], + ) -> None: + parsers = [] + for cls in _parser_types: + parser = cls(self.results, project, self.root_project, + macro_manifest) + parsers.append(parser) + + # per-project cache. 
+ self._loaded_file_cache.clear() + + for parser in parsers: + for path in parser.search(): + self.parse_with_cache(path, parser, old_results) + + def load_only_macros(self) -> Manifest: + old_results = self.read_parse_results() + self._load_macros(old_results, internal_manifest=None) + # make a manifest with just the macros to get the context + macro_manifest = Manifest.from_macros( + macros=self.results.macros, + files=self.results.files + ) + return macro_manifest - for unique_id, source in sources.items(): - if unique_id in self.nodes: - dbt.exceptions.raise_duplicate_resource_name( - source, self.nodes[unique_id], - ) - self.nodes[unique_id] = source + def load(self, internal_manifest: Optional[Manifest] = None): + old_results = self.read_parse_results() + self._load_macros(old_results, internal_manifest=internal_manifest) + # make a manifest with just the macros to get the context + macro_manifest = Manifest.from_macros( + macros=self.results.macros, + files=self.results.files + ) - for name, patch in patches.items(): - if name in self.patches: - dbt.exceptions.raise_duplicate_patch_name( - name, patch, self.patches[name] + for project in self.all_projects.values(): + # parse a single project + self.parse_project(project, macro_manifest, old_results) + + def write_parse_results(self): + path = os.path.join(self.root_project.target_path, + PARTIAL_PARSE_FILE_NAME) + make_directory(self.root_project.target_path) + with open(path, 'wb') as fp: + pickle.dump(self.results, fp) + + def _matching_parse_results(self, result: ParseResult) -> bool: + """Compare the global hashes of the read-in parse results' values to + the known ones, and return if it is ok to re-use the results. + """ + valid = True + if self.results.vars_hash != result.vars_hash: + logger.debug('vars hash mismatch, cache invalidated') + valid = False + if self.results.profile_hash != result.profile_hash: + logger.debug('profile hash mismatch, cache invalidated') + valid = False + + missing_keys = { + k for k in self.results.project_hashes + if k not in result.project_hashes + } + if missing_keys: + logger.debug( + 'project hash mismatch: values missing, cache invalidated: {}' + .format(missing_keys) + ) + valid = False + + for key, new_value in self.results.project_hashes.items(): + if key in result.project_hashes: + old_value = result.project_hashes[key] + if new_value != old_value: + logger.debug( + 'For key {}, hash mismatch ({} -> {}), cache ' + 'invalidated' + .format(key, old_value, new_value) ) - self.patches[name] = patch - - def load(self, internal_manifest=None): - self._load_macros(internal_manifest=internal_manifest) - # make a manifest with just the macros to get the context - self.macro_manifest = Manifest(macros=self.macros, nodes={}, docs={}, - generated_at=datetime.utcnow(), - disabled=[]) - self._load_nodes() - self._load_docs() - self._load_schema_tests() - - def create_manifest(self): + valid = False + return valid + + def read_parse_results(self) -> Optional[ParseResult]: + if not dbt.flags.PARTIAL_PARSE: + return None + path = os.path.join(self.root_project.target_path, + PARTIAL_PARSE_FILE_NAME) + + if os.path.exists(path): + try: + with open(path, 'rb') as fp: + result: ParseResult = pickle.load(fp) + # keep this check inside the try/except in case something about + # the file has changed in weird ways, perhaps due to being a + # different version of dbt + if self._matching_parse_results(result): + return result + except Exception as exc: + logger.debug( + 'Failed to load parsed file from disk at {}: 
{}' + .format(path, exc), + exc_info=True + ) + + return None + + def create_manifest(self) -> Manifest: + nodes: Dict[str, CompileResultNode] = {} + nodes.update(self.results.nodes) + nodes.update(self.results.sources) + disabled = [] + for value in self.results.disabled.values(): + disabled.extend(value) manifest = Manifest( - nodes=self.nodes, - macros=self.macros, - docs=self.docs, + nodes=nodes, + macros=self.results.macros, + docs=self.results.docs, generated_at=datetime.utcnow(), config=self.root_project, - disabled=self.disabled + disabled=disabled, + files=self.results.files, + ) + manifest.patch_nodes(self.results.patches) + manifest = ParserUtils.process_sources( + manifest, self.root_project.project_name + ) + manifest = ParserUtils.process_refs( + manifest, self.root_project.project_name + ) + manifest = ParserUtils.process_docs( + manifest, self.root_project.project_name ) - manifest.add_nodes(self.tests) - manifest.patch_nodes(self.patches) - manifest = ParserUtils.process_sources(manifest, self.root_project) - manifest = ParserUtils.process_refs(manifest, - self.root_project.project_name) - manifest = ParserUtils.process_docs(manifest, self.root_project) return manifest @classmethod - def _load_from_projects(cls, root_config, projects, internal_manifest): + def load_all( + cls, + root_config: RuntimeConfig, + internal_manifest: Optional[Manifest] = None + ) -> Manifest: + projects = load_all_projects(root_config) loader = cls(root_config, projects) loader.load(internal_manifest=internal_manifest) - return loader.create_manifest() - - @classmethod - def load_all(cls, root_config, internal_manifest=None): - projects = load_all_projects(root_config) - manifest = cls._load_from_projects(root_config, projects, - internal_manifest) + loader.write_parse_results() + manifest = loader.create_manifest() _check_manifest(manifest, root_config) return manifest @classmethod - def load_internal(cls, root_config): + def load_internal(cls, root_config: RuntimeConfig) -> Manifest: projects = load_internal_projects(root_config) - return cls._load_from_projects(root_config, projects, None) + loader = cls(root_config, projects) + return loader.load_only_macros() def _check_resource_uniqueness(manifest): @@ -248,7 +373,7 @@ def _project_directories(config): yield full_obj -def load_all_projects(config): +def load_all_projects(config) -> Mapping[str, Project]: all_projects = {config.project_name: config} project_paths = itertools.chain( internal_project_names(), diff --git a/core/dbt/main.py b/core/dbt/main.py index 92bc7be25a7..beb1a397f7d 100644 --- a/core/dbt/main.py +++ b/core/dbt/main.py @@ -92,8 +92,8 @@ def main(args=None): exit_code = e.code except BaseException as e: - logger.warn("Encountered an error:") - logger.warn(str(e)) + logger.warning("Encountered an error:") + logger.warning(str(e)) if logger_initialized(): logger.debug(traceback.format_exc()) @@ -359,6 +359,7 @@ def _build_compile_subparser(subparsers, base_subparser): "analysis files. \nCompiled SQL files are written to the target/" "directory.") sub.set_defaults(cls=compile_task.CompileTask, which='compile') + sub.add_argument('--parse-only', action='store_true') return sub @@ -655,6 +656,14 @@ def parse_args(args): help='''Display debug logging during dbt execution. 
Useful for debugging and making bug reports.''') + p.add_argument( + '--no-write-json', + action='store_false', + dest='write_json', + help='''If set, skip writing the manifest and run_results.json files to + disk''' + ) + p.add_argument( '-S', '--strict', @@ -670,6 +679,15 @@ def parse_args(args): configurations with no associated models, invalid test configurations, and missing sources/refs in tests''') + p.add_argument( + '--partial-parse', + action='store_true', + help='''Allow for partial parsing by looking for and writing to a + pickle file in the target directory. + WARNING: This can result in unexpected behavior if you use env_var()! + ''' + ) + # if set, run dbt in single-threaded mode: thread count is ignored, and # calls go through `map` instead of the thread pool. This is useful for # getting performance information about aspects of dbt that normally run in diff --git a/core/dbt/parser/__init__.py b/core/dbt/parser/__init__.py index cbecdaa1201..b5855f1fead 100644 --- a/core/dbt/parser/__init__.py +++ b/core/dbt/parser/__init__.py @@ -1,26 +1,17 @@ +from .analysis import AnalysisParser # noqa +from .base import Parser, ConfiguredParser # noqa +from .data_test import DataTestParser # noqa +from .docs import DocumentationParser # noqa +from .hooks import HookParser # noqa +from .macros import MacroParser # noqa +from .models import ModelParser # noqa +from .results import ParseResult # noqa +from .schemas import SchemaParser # noqa +from .seeds import SeedParser # noqa +from .snapshots import SnapshotParser # noqa +from .util import ParserUtils # noqa -from .analysis import AnalysisParser -from .snapshots import SnapshotParser -from .data_test import DataTestParser -from .docs import DocumentationParser -from .hooks import HookParser -from .macros import MacroParser -from .models import ModelParser -from .schemas import SchemaParser -from .seeds import SeedParser - -from .util import ParserUtils - -__all__ = [ - 'AnalysisParser', - 'SnapshotParser', - 'DataTestParser', - 'DocumentationParser', - 'HookParser', - 'MacroParser', - 'ModelParser', - 'SchemaParser', - 'SeedParser', - - 'ParserUtils', -] +from . 
import ( # noqa + analysis, base, data_test, docs, hooks, macros, models, results, schemas, + snapshots, util +) diff --git a/core/dbt/parser/analysis.py b/core/dbt/parser/analysis.py index b18f85cae1a..8d13d99368b 100644 --- a/core/dbt/parser/analysis.py +++ b/core/dbt/parser/analysis.py @@ -1,27 +1,24 @@ import os -from typing import Dict, Any -from dbt.contracts.graph.parsed import ParsedAnalysisNode, ParsedRPCNode -from dbt.parser.base_sql import BaseSqlParser +from dbt.contracts.graph.parsed import ParsedAnalysisNode +from dbt.node_types import NodeType +from dbt.parser.base import SimpleSQLParser +from dbt.parser.search import FilesystemSearcher, FileBlock -class AnalysisParser(BaseSqlParser): - @classmethod - def get_compiled_path(cls, name, relative_path): - return os.path.join('analysis', relative_path) - - def parse_from_dict( - self, - parsed_dict: Dict[str, Any] - ) -> ParsedAnalysisNode: - """Given a dictionary, return the parsed entity for this parser""" - return ParsedAnalysisNode.from_dict(parsed_dict) +class AnalysisParser(SimpleSQLParser[ParsedAnalysisNode]): + def get_paths(self): + return FilesystemSearcher( + self.project, self.project.analysis_paths, '.sql' + ) + def parse_from_dict(self, dct, validate=True) -> ParsedAnalysisNode: + return ParsedAnalysisNode.from_dict(dct, validate=validate) -class RPCCallParser(BaseSqlParser): - def get_compiled_path(cls, name, relative_path): - return os.path.join('rpc', relative_path) + @property + def resource_type(self) -> NodeType: + return NodeType.Analysis - def parse_from_dict(self, parsed_dict: Dict[str, Any]) -> ParsedRPCNode: - """Given a dictionary, return the parsed entity for this parser""" - return ParsedRPCNode.from_dict(parsed_dict) + @classmethod + def get_compiled_path(cls, block: FileBlock): + return os.path.join('analysis', block.path.relative_path) diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index 5ea1cc65d8b..fea9ba95fc4 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -1,78 +1,137 @@ import abc import os -from typing import Dict, Any +from typing import ( + List, Dict, Any, Callable, Iterable, Optional, Generic, TypeVar +) -import dbt.exceptions -import dbt.flags -import dbt.include -import dbt.utils -import dbt.hooks -import dbt.clients.jinja -import dbt.context.parser +from hologram import ValidationError -from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME -from dbt.utils import coalesce -from dbt.logger import GLOBAL_LOGGER as logger -from dbt.contracts.project import ProjectList -from dbt.parser.source_config import SourceConfig +import dbt.context.parser +import dbt.flags from dbt import deprecations from dbt import hooks +from dbt.clients.jinja import get_rendered +from dbt.config import Project, RuntimeConfig +from dbt.contracts.graph.manifest import ( + Manifest, SourceFile, FilePath, FileHash +) +from dbt.contracts.graph.parsed import HasUniqueID +from dbt.contracts.graph.unparsed import UnparsedNode +from dbt.exceptions import ( + CompilationException, validator_error_message +) +from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME +from dbt.node_types import NodeType, UnparsedNodeType +from dbt.source_config import SourceConfig +from dbt.parser.results import ParseResult, ManifestNodes +from dbt.parser.search import FileBlock +from dbt.clients.system import load_file_contents +# internally, the parser may store a less-restrictive type that will be +# transformed into the final type. 
But it will have to be derived from +# ParsedNode to be operable. +FinalValue = TypeVar('FinalValue', bound=HasUniqueID) +IntermediateValue = TypeVar('IntermediateValue', bound=HasUniqueID) -class BaseParser(metaclass=abc.ABCMeta): - def __init__(self, root_project_config, all_projects: ProjectList): - self.root_project_config = root_project_config - self.all_projects = all_projects - if dbt.flags.STRICT_MODE: - dct = { - 'projects': { - name: project.to_project_config(with_packages=True) - for name, project in all_projects.items() - } - } - ProjectList.from_dict(dct, validate=True) +IntermediateNode = TypeVar('IntermediateNode', bound=Any) +FinalNode = TypeVar('FinalNode', bound=ManifestNodes) - @property - def default_schema(self): - return getattr(self.root_project_config.credentials, 'schema', - 'public') - @property - def default_database(self): - return getattr(self.root_project_config.credentials, 'database', 'dbt') +RelationUpdate = Callable[[Optional[str], IntermediateNode], str] +ConfiguredBlockType = TypeVar('ConfiguredBlockType', bound=FileBlock) - def load_and_parse(self, *args, **kwargs): - raise dbt.exceptions.NotImplementedException("Not implemented") - @classmethod - def get_path(cls, resource_type, package_name, resource_name): - """Returns a unique identifier for a resource""" +class BaseParser(Generic[FinalValue]): + def __init__(self, results: ParseResult, project: Project) -> None: + self.results = results + self.project = project + # this should be a superset of [x.path for x in self.results.files] + # because we fill it via search() + self.searched: List[FilePath] = [] - return "{}.{}.{}".format(resource_type, package_name, resource_name) + @abc.abstractmethod + def get_paths(self) -> Iterable[FilePath]: + pass - @classmethod - def get_fqn(cls, node, package_project_config, extra=[]): - parts = dbt.utils.split_path(node.path) - name, _ = os.path.splitext(parts[-1]) - fqn = ([package_project_config.project_name] + - parts[:-1] + - extra + - [name]) + def search(self) -> List[FilePath]: + self.searched = list(self.get_paths()) + return self.searched - return fqn + @abc.abstractmethod + def parse_file(self, block: FileBlock) -> None: + pass + @abc.abstractproperty + def resource_type(self) -> NodeType: + pass -class MacrosKnownParser(BaseParser): - def __init__(self, root_project_config, all_projects, macro_manifest): - super().__init__( - root_project_config=root_project_config, - all_projects=all_projects - ) + def generate_unique_id(self, resource_name: str) -> str: + """Returns a unique identifier for a resource""" + return "{}.{}.{}".format(self.resource_type, + self.project.project_name, + resource_name) + + def load_file(self, path: FilePath) -> SourceFile: + file_contents = load_file_contents(path.absolute_path, strip=False) + checksum = FileHash.from_contents(file_contents) + source_file = SourceFile(path=path, checksum=checksum) + source_file.contents = file_contents.strip() + return source_file + + def parse_file_from_path(self, path: FilePath): + block = FileBlock(file=self.load_file(path)) + self.parse_file(block) + + +class Parser(BaseParser[FinalValue], Generic[FinalValue]): + def __init__( + self, + results: ParseResult, + project: Project, + root_project: RuntimeConfig, + macro_manifest: Manifest, + ) -> None: + super().__init__(results, project) + self.root_project = root_project self.macro_manifest = macro_manifest - self._get_schema_func = None - self._get_alias_func = None - def get_schema_func(self): + +class ConfiguredParser( + Parser[FinalNode], 
+ Generic[ConfiguredBlockType, IntermediateNode, FinalNode], +): + def __init__( + self, + results: ParseResult, + project: Project, + root_project: RuntimeConfig, + macro_manifest: Manifest, + ) -> None: + super().__init__(results, project, root_project, macro_manifest) + self._get_schema_func: Optional[RelationUpdate] = None + self._get_alias_func: Optional[RelationUpdate] = None + + @abc.abstractclassmethod + def get_compiled_path(cls, block: ConfiguredBlockType): + pass + + @abc.abstractmethod + def parse_from_dict(self, dict, validate=True) -> IntermediateNode: + pass + + @abc.abstractproperty + def resource_type(self) -> NodeType: + pass + + @property + def default_schema(self): + return self.root_project.credentials.schema + + @property + def default_database(self): + return self.root_project.credentials.database + + def get_schema_func(self) -> RelationUpdate: """The get_schema function is set by a few different things: - if there is a 'generate_schema_name' macro in the root project, it will be used. @@ -87,7 +146,7 @@ def get_schema_func(self): get_schema_macro = self.macro_manifest.find_macro_by_name( 'generate_schema_name', - self.root_project_config.project_name + self.root_project.project_name ) if get_schema_macro is None: get_schema_macro = self.macro_manifest.find_macro_by_name( @@ -100,7 +159,7 @@ def get_schema(custom_schema_name=None, node=None): return self.default_schema else: root_context = dbt.context.parser.generate_macro( - get_schema_macro, self.root_project_config, + get_schema_macro, self.root_project, self.macro_manifest ) get_schema = get_schema_macro.generator(root_context) @@ -108,7 +167,7 @@ def get_schema(custom_schema_name=None, node=None): self._get_schema_func = get_schema return self._get_schema_func - def get_alias_func(self): + def get_alias_func(self) -> RelationUpdate: """The get_alias function is set by a few different things: - if there is a 'generate_alias_name' macro in the root project, it will be used. @@ -123,7 +182,7 @@ def get_alias_func(self): get_alias_macro = self.macro_manifest.find_macro_by_name( 'generate_alias_name', - self.root_project_config.project_name + self.root_project.project_name ) if get_alias_macro is None: get_alias_macro = self.macro_manifest.find_macro_by_name( @@ -140,7 +199,7 @@ def get_alias(custom_alias_name, node): return custom_alias_name else: root_context = dbt.context.parser.generate_macro( - get_alias_macro, self.root_project_config, + get_alias_macro, self.root_project, self.macro_manifest ) get_alias = get_alias_macro.generator(root_context) @@ -148,6 +207,16 @@ def get_alias(custom_alias_name, node): self._get_alias_func = get_alias return self._get_alias_func + def get_fqn(self, path: str, name: str) -> List[str]: + """Get the FQN for the node. This impacts node selection and config + application. + """ + no_ext = os.path.splitext(path)[0] + fqn = [self.project.project_name] + fqn.extend(dbt.utils.split_path(no_ext)[:-1]) + fqn.append(name) + return fqn + def _mangle_hooks(self, config): """Given a config dict that may have `pre-hook`/`post-hook` keys, convert it from the yucky maybe-a-string, maybe-a-dict to a dict. 
@@ -157,68 +226,89 @@ def _mangle_hooks(self, config): if key in config: config[key] = [hooks.get_hook_dict(h) for h in config[key]] - def _build_intermediate_node_dict(self, config, node_dict, node_path, - package_project_config, tags, fqn, - snapshot_config, column_name): - """Update the unparsed node dictionary and build the basis for an - intermediate ParsedNode that will be passed into the renderer + def _create_error_node( + self, name: str, path: str, original_file_path: str, raw_sql: str, + ) -> UnparsedNode: + """If we hit an error before we've actually parsed a node, provide some + level of useful information by attaching this to the exception. """ - # Set this temporarily. Not the full config yet (as config() hasn't - # been called from jinja yet). But the Var() call below needs info - # about project level configs b/c they might contain refs. - # TODO: Restructure this? - config_dict = coalesce(snapshot_config, {}) - config_dict.update(config.config) - self._mangle_hooks(config_dict) + # this is a bit silly, but build an UnparsedNode just for error + # message reasons + return UnparsedNode( + name=name, + resource_type=UnparsedNodeType(self.resource_type), + path=path, + original_file_path=original_file_path, + root_path=self.project.project_root, + package_name=self.project.project_name, + raw_sql=raw_sql, + ) - node_dict.update({ - 'refs': [], - 'sources': [], - 'depends_on': { - 'nodes': [], - 'macros': [], - }, - 'unique_id': node_path, - 'fqn': fqn, - 'tags': tags, - 'config': config_dict, - # Set these temporarily so get_rendered() has access to a schema, - # database, and alias. + def _create_parsetime_node( + self, + block: ConfiguredBlockType, + path: str, + config: SourceConfig, + name=None, + **kwargs, + ) -> IntermediateNode: + """Create the node that will be passed in to the parser context for + "rendering". Some information may be partial, as it'll be updated by + config() and any ref()/source() calls discovered during rendering. + """ + if name is None: + name = block.name + dct = { + 'alias': name, 'schema': self.default_schema, 'database': self.default_database, - 'alias': node_dict.get('name'), - }) - - # if there's a column, it should end up part of the ParsedNode - if column_name is not None: - node_dict['column_name'] = column_name - - return node_dict + 'fqn': config.fqn, + 'name': name, + 'root_path': self.project.project_root, + 'resource_type': self.resource_type, + 'path': path, + 'original_file_path': block.path.original_file_path, + 'package_name': self.project.project_name, + 'raw_sql': block.contents, + 'unique_id': self.generate_unique_id(name), + 'config': self.config_dict(config), + } + dct.update(kwargs) + try: + return self.parse_from_dict(dct) + except ValidationError as exc: + msg = validator_error_message(exc) + # this is a bit silly, but build an UnparsedNode just for error + # message reasons + node = self._create_error_node( + name=block.name, + path=path, + original_file_path=block.path.original_file_path, + raw_sql=block.contents, + ) + raise CompilationException(msg, node=node) - def _render_with_context(self, parsed_node, config): + def render_with_context( + self, parsed_node: IntermediateNode, config: SourceConfig + ) -> None: """Given the parsed node and a SourceConfig to use during parsing, render the node's sql wtih macro capture enabled. Note: this mutates the config object when config() calls are rendered. 
""" context = dbt.context.parser.generate( - parsed_node, - self.root_project_config, - self.macro_manifest, - config) + parsed_node, self.root_project, self.macro_manifest, config + ) - dbt.clients.jinja.get_rendered( - parsed_node.raw_sql, context, parsed_node, - capture_macros=True) + get_rendered(parsed_node.raw_sql, context, parsed_node, + capture_macros=True) - def _update_parsed_node_info(self, parsed_node, config): - """Given the SourceConfig used for parsing and the parsed node, - generate and set the true values to use, overriding the temporary parse - values set in _build_intermediate_parsed_node. - """ + def update_parsed_node_schema( + self, parsed_node: IntermediateNode, config_dict: Dict[str, Any] + ) -> None: # Special macro defined in the global project. Use the root project's # definition, not the current package - schema_override = config.config.get('schema') + schema_override = config_dict.get('schema') get_schema = self.get_schema_func() try: schema = get_schema(schema_override, parsed_node) @@ -230,77 +320,115 @@ def _update_parsed_node_info(self, parsed_node, config): if too_many_args not in str(exc): raise deprecations.warn('generate-schema-name-single-arg') - schema = get_schema(schema_override) + schema = get_schema(schema_override) # type: ignore parsed_node.schema = schema.strip() - alias_override = config.config.get('alias') + def update_parsed_node_alias( + self, parsed_node: IntermediateNode, config_dict: Dict[str, Any] + ) -> None: + alias_override = config_dict.get('alias') get_alias = self.get_alias_func() parsed_node.alias = get_alias(alias_override, parsed_node).strip() - parsed_node.database = config.config.get( + def update_parsed_node_config( + self, parsed_node: IntermediateNode, config_dict: Dict[str, Any] + ) -> None: + # Overwrite node config + final_config_dict = parsed_node.config.to_dict() + final_config_dict.update(config_dict) + # re-mangle hooks, in case we got new ones + self._mangle_hooks(final_config_dict) + parsed_node.config = parsed_node.config.from_dict(final_config_dict) + + def update_parsed_node( + self, parsed_node: IntermediateNode, config: SourceConfig + ) -> None: + """Given the SourceConfig used for parsing and the parsed node, + generate and set the true values to use, overriding the temporary parse + values set in _build_intermediate_parsed_node. 
+ """ + config_dict = config.config + self.update_parsed_node_schema(parsed_node, config_dict) + self.update_parsed_node_alias(parsed_node, config_dict) + + parsed_node.database = config_dict.get( 'database', self.default_database ).strip() # Set tags on node provided in config blocks - model_tags = config.config.get('tags', []) + model_tags = config_dict.get('tags', []) parsed_node.tags.extend(model_tags) + self.update_parsed_node_config(parsed_node, config_dict) - # Overwrite node config - config_dict = parsed_node.config.to_dict() - config_dict.update(config.config) - # re-mangle hooks, in case we got new ones + def initial_config(self, fqn: List[str]) -> SourceConfig: + return SourceConfig(self.root_project, self.project, fqn, + self.resource_type) + + def config_dict(self, config: SourceConfig) -> Dict[str, Any]: + config_dict = config.config self._mangle_hooks(config_dict) - parsed_node.config = parsed_node.config.from_dict(config_dict) + return config_dict - @abc.abstractmethod - def parse_from_dict(self, parsed_dict: Dict[str, Any]) -> Any: - """Given a dictionary, return the parsed entity for this parser""" + def render_update( + self, node: IntermediateNode, config: SourceConfig + ) -> None: + try: + self.render_with_context(node, config) + self.update_parsed_node(node, config) + except ValidationError as exc: + # we got a ValidationError - probably bad types in config() + msg = validator_error_message(exc) + raise CompilationException(msg, node=node) from exc + + def add_result_node(self, block: FileBlock, node: ManifestNodes): + if node.config.enabled: + self.results.add_node(block.file, node) + else: + self.results.add_disabled(block.file, node) - def parse_node(self, node, node_path, package_project_config, tags=None, - fqn_extra=None, fqn=None, snapshot_config=None, - column_name=None): - """Parse a node, given an UnparsedNode and any other required information. + def parse_node(self, block: ConfiguredBlockType) -> FinalNode: + compiled_path = self.get_compiled_path(block) + fqn = self.get_fqn(compiled_path, block.name) - snapshot_config should be set if the node is an Snapshot node. - column_name should be set if the node is a Test node associated with a - particular column. 
- """ - logger.debug("Parsing {}".format(node_path)) + config = self.initial_config(fqn) + + node = self._create_parsetime_node( + block=block, + path=compiled_path, + config=config + ) + self.render_update(node, config) + result = self.transform(node) + self.add_result_node(block, result) + return result - tags = coalesce(tags, []) - fqn_extra = coalesce(fqn_extra, []) + @abc.abstractmethod + def parse_file(self, file_block: FileBlock) -> None: + pass - if fqn is None: - fqn = self.get_fqn(node, package_project_config, fqn_extra) + @abc.abstractmethod + def transform(self, node: IntermediateNode) -> FinalNode: + pass - config = SourceConfig( - self.root_project_config, - package_project_config, - fqn, - node.resource_type) - parsed_dict = self._build_intermediate_node_dict( - config, node.to_dict(), node_path, config, tags, fqn, - snapshot_config, column_name - ) - parsed_node = self.parse_from_dict(parsed_dict) +class SimpleParser( + ConfiguredParser[ConfiguredBlockType, FinalNode, FinalNode], + Generic[ConfiguredBlockType, FinalNode] +): + def transform(self, node): + return node - self._render_with_context(parsed_node, config) - self._update_parsed_node_info(parsed_node, config) - parsed_node.to_dict(validate=True) +class SQLParser( + ConfiguredParser[FileBlock, IntermediateNode, FinalNode], + Generic[IntermediateNode, FinalNode] +): + def parse_file(self, file_block: FileBlock) -> None: + self.parse_node(file_block) - return parsed_node - def check_block_parsing(self, name, path, contents): - """Check if we were able to extract toplevel blocks from the given - contents. Return True if extraction was successful (no exceptions), - False if it fails. - """ - if not dbt.flags.TEST_NEW_PARSER: - return True - try: - dbt.clients.jinja.extract_toplevel_blocks(contents) - except Exception: - return False - return True +class SimpleSQLParser( + SQLParser[FinalNode, FinalNode] +): + def transform(self, node): + return node diff --git a/core/dbt/parser/base_sql.py b/core/dbt/parser/base_sql.py deleted file mode 100644 index bc6a6b504ca..00000000000 --- a/core/dbt/parser/base_sql.py +++ /dev/null @@ -1,149 +0,0 @@ - -import os - -import dbt.contracts.project -import dbt.clients.system -import dbt.utils -import dbt.flags -from dbt.exceptions import ( - CompilationException, InternalException, NotImplementedException, - raise_duplicate_resource_name, validator_error_message -) - -from dbt.contracts.graph.unparsed import UnparsedNode -from dbt.parser.base import MacrosKnownParser -from dbt.node_types import NodeType - -from hologram import ValidationError - - -class BaseSqlParser(MacrosKnownParser): - UnparsedNodeType = UnparsedNode - - @classmethod - def get_compiled_path(cls, name, relative_path): - raise NotImplementedException("Not implemented") - - def load_and_parse(self, package_name, root_dir, relative_dirs, - resource_type, tags=None): - """Load and parse models in a list of directories. 
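To make the new ConfiguredParser.parse_node flow concrete: a block becomes a parse-time node carrying placeholder schema/database/alias values, the node's SQL is rendered with macro capture, and only then does update_parsed_node overwrite the placeholders with the final values. A toy, dependency-free sketch of that two-phase idea (the class and helper names here are invented for illustration and are not dbt contracts):

from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyNode:
    name: str
    schema: str
    alias: str
    config: Dict[str, Any] = field(default_factory=dict)


def create_parsetime_node(name: str, default_schema: str) -> ToyNode:
    # parse-time values are placeholders; config() calls discovered while
    # rendering may override them afterwards
    return ToyNode(name=name, schema=default_schema, alias=name)


def update_parsed_node(node: ToyNode, config: Dict[str, Any]) -> None:
    # the real parser goes through generate_schema_name / generate_alias_name;
    # here overrides are simply applied when present
    node.schema = config.get('schema', node.schema)
    node.alias = config.get('alias', node.alias)
    node.config.update(config)


node = create_parsetime_node('stg_orders', default_schema='analytics')
update_parsed_node(node, {'schema': 'staging', 'materialized': 'view'})
print(node.schema, node.alias)   # staging stg_orders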
Returns a dict - that maps unique ids onto ParsedNodes""" - - extension = "[!.#~]*.sql" - - if tags is None: - tags = [] - - file_matches = dbt.clients.system.find_matching( - root_dir, - relative_dirs, - extension) - - result = [] - - for file_match in file_matches: - file_contents = dbt.clients.system.load_file_contents( - file_match.get('absolute_path')) - - parts = dbt.utils.split_path(file_match.get('relative_path', '')) - name, _ = os.path.splitext(parts[-1]) - - path = self.get_compiled_path(name, - file_match.get('relative_path')) - - original_file_path = os.path.join( - file_match.get('searched_path'), - file_match.get('relative_path')) - - result.append({ - 'name': name, - 'root_path': root_dir, - 'resource_type': resource_type, - 'path': path, - 'original_file_path': original_file_path, - 'package_name': package_name, - 'raw_sql': file_contents - }) - - return self.parse_sql_nodes(result, tags) - - def parse_sql_node(self, node_dict, tags=None): - if tags is None: - tags = [] - - node = self.UnparsedNodeType.from_dict(node_dict) - package_name = node.package_name - - unique_id = self.get_path(node.resource_type, - package_name, - node.name) - - project = self.all_projects.get(package_name) - - parse_ok = True - if node.resource_type == NodeType.Model: - parse_ok = self.check_block_parsing( - node.name, node.original_file_path, node.raw_sql - ) - - try: - node_parsed = self.parse_node(node, unique_id, project, tags=tags) - except ValidationError as exc: - # we got a ValidationError - probably bad types in config() - msg = validator_error_message(exc) - raise CompilationException(msg, node=node) from exc - - if not parse_ok: - # if we had a parse error in parse_node, we would not get here. So - # this means we rejected a good file :( - raise InternalException( - 'the block parser rejected a good node: {} was marked invalid ' - 'but is actually valid!'.format(node.original_file_path) - ) - return unique_id, node_parsed - - def parse_sql_nodes(self, nodes, tags=None): - if tags is None: - tags = [] - - results = SQLParseResult() - - for n in nodes: - node_path, node_parsed = self.parse_sql_node(n, tags) - - # Ignore disabled nodes - if not node_parsed.config.enabled: - results.disable(node_parsed) - continue - - results.keep(node_path, node_parsed) - - return results - - -class SQLParseResult: - def __init__(self): - self.parsed = {} - self.disabled = [] - - def result(self, unique_id, node): - if node.config['enabled']: - self.keep(unique_id, node) - else: - self.disable(node) - - def disable(self, node): - self.disabled.append(node) - - def keep(self, unique_id, node): - if unique_id in self.parsed: - raise_duplicate_resource_name( - self.parsed[unique_id], node - ) - - self.parsed[unique_id] = node - - def update(self, other): - self.disabled.extend(other.disabled) - for unique_id, node in other.parsed.items(): - self.keep(unique_id, node) diff --git a/core/dbt/parser/data_test.py b/core/dbt/parser/data_test.py index faab88f7cd1..c753bb26ef9 100644 --- a/core/dbt/parser/data_test.py +++ b/core/dbt/parser/data_test.py @@ -1,15 +1,29 @@ -from typing import Dict, Any - from dbt.contracts.graph.parsed import ParsedTestNode -from dbt.parser.base_sql import BaseSqlParser -import dbt.utils +from dbt.node_types import NodeType +from dbt.parser.base import SimpleSQLParser +from dbt.parser.search import FilesystemSearcher, FileBlock +from dbt.utils import get_pseudo_test_path -class DataTestParser(BaseSqlParser): - @classmethod - def get_compiled_path(cls, name, relative_path): - return 
dbt.utils.get_pseudo_test_path(name, relative_path, 'data_test') +class DataTestParser(SimpleSQLParser[ParsedTestNode]): + def get_paths(self): + return FilesystemSearcher( + self.project, self.project.test_paths, '.sql' + ) + + def parse_from_dict(self, dct, validate=True) -> ParsedTestNode: + return ParsedTestNode.from_dict(dct, validate=validate) - def parse_from_dict(self, parsed_dict: Dict[str, Any]) -> ParsedTestNode: - """Given a dictionary, return the parsed entity for this parser""" - return ParsedTestNode.from_dict(parsed_dict) + @property + def resource_type(self) -> NodeType: + return NodeType.Test + + def transform(self, node): + if 'data' not in node.tags: + node.tags.append('data') + return node + + @classmethod + def get_compiled_path(cls, block: FileBlock): + return get_pseudo_test_path(block.name, block.path.relative_path, + 'data_test') diff --git a/core/dbt/parser/docs.py b/core/dbt/parser/docs.py index 456850f8301..eb43a737854 100644 --- a/core/dbt/parser/docs.py +++ b/core/dbt/parser/docs.py @@ -1,82 +1,55 @@ -import os -from typing import Dict, Any +from typing import Iterable import jinja2.runtime -import dbt.exceptions -from dbt.parser.base import BaseParser +from dbt.clients.jinja import get_template from dbt.contracts.graph.unparsed import UnparsedDocumentationFile from dbt.contracts.graph.parsed import ParsedDocumentation -from dbt.clients.jinja import extract_toplevel_blocks, get_template -from dbt.clients import system +from dbt.exceptions import CompilationException, InternalException +from dbt.node_types import NodeType +from dbt.parser.base import Parser +from dbt.parser.search import ( + FullBlock, FileBlock, FilesystemSearcher, BlockSearcher +) +from dbt.utils import deep_merge, DOCS_PREFIX + + +class DocumentationParser(Parser[ParsedDocumentation]): + def get_paths(self): + return FilesystemSearcher( + project=self.project, + relative_dirs=self.project.docs_paths, + extension='.md', + ) + + @property + def resource_type(self) -> NodeType: + return NodeType.Documentation - -class DocumentationParser(BaseParser): @classmethod - def load_file(cls, package_name, root_dir, relative_dirs): - """Load and parse documentation in a list of projects. Returns a list - of ParsedNodes. 
- """ - extension = "[!.#~]*.md" - - file_matches = system.find_matching(root_dir, relative_dirs, extension) - - for file_match in file_matches: - file_contents = system.load_file_contents( - file_match.get('absolute_path'), - strip=False) - - parts = dbt.utils.split_path(file_match.get('relative_path', '')) - name, _ = os.path.splitext(parts[-1]) + def get_compiled_path(cls, block: FileBlock): + return block.path.relative_path - path = file_match.get('relative_path') - original_file_path = os.path.join( - file_match.get('searched_path'), - path) - - yield UnparsedDocumentationFile( - root_path=root_dir, - path=path, - original_file_path=original_file_path, - package_name=package_name, - file_contents=file_contents - ) - - def parse(self, docfile): - try: - blocks = extract_toplevel_blocks( - docfile.file_contents, - allowed_blocks={'docs'}, - collect_raw_data=False - ) - except dbt.exceptions.CompilationException as exc: - if exc.node is None: - exc.node = docfile - raise - - for block in blocks: - try: - template = get_template(block.full_block, {}) - except dbt.exceptions.CompilationException as e: - e.node = docfile - raise - yield from self._parse_template_docs(template, docfile) + def generate_unique_id(self, resource_name: str) -> str: + # because docs are in their own graph namespace, node type doesn't + # need to be part of the unique ID. + return '{}.{}'.format(self.project.project_name, resource_name) + # TODO: could this class just render() the tag.contents() and skip this + # whole extra module.__dict__.items() thing? def _parse_template_docs(self, template, docfile): for key, item in template.module.__dict__.items(): if type(item) != jinja2.runtime.Macro: continue - if not key.startswith(dbt.utils.DOCS_PREFIX): + if not key.startswith(DOCS_PREFIX): continue - name = key.replace(dbt.utils.DOCS_PREFIX, '') + name = key.replace(DOCS_PREFIX, '') - # because docs are in their own graph namespace, node type doesn't - # need to be part of the unique ID. 
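Because docs live in their own graph namespace, their unique ids are just '<project>.<doc name>' rather than '<resource type>.<project>.<name>'. A small sketch of what the docs parser ends up producing for a markdown file containing a single docs block (the regex is a simplification of the real block extraction, and 'my_project'/'orders_status' are made-up names):

import re

md = """
{% docs orders_status %}
One of placed / shipped / returned.
{% enddocs %}
"""

# naive stand-in for the real `docs` block extraction, good enough for a sketch
for name in re.findall(r'{%\s*docs\s+(\w+)\s*%}', md):
    unique_id = '{}.{}'.format('my_project', name)   # mirrors generate_unique_id above
    print(unique_id)   # my_project.orders_status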
- unique_id = '{}.{}'.format(docfile.package_name, name) + unique_id = self.generate_unique_id(name) - merged = dbt.utils.deep_merge( + merged = deep_merge( docfile.to_dict(), { 'name': name, @@ -86,19 +59,36 @@ def _parse_template_docs(self, template, docfile): ) yield ParsedDocumentation.from_dict(merged) - def load_and_parse(self, package_name, root_dir, relative_dirs): - to_return = {} - for docfile in self.load_file(package_name, root_dir, relative_dirs): - for parsed in self.parse(docfile): - if parsed.unique_id in to_return: - dbt.exceptions.raise_duplicate_resource_name( - to_return[parsed.unique_id], parsed - ) - to_return[parsed.unique_id] = parsed - return to_return - - def parse_from_dict( - self, parsed_dict: Dict[str, Any] - ) -> ParsedDocumentation: - """Given a dictionary, return the parsed entity for this parser""" - return ParsedDocumentation.from_dict(parsed_dict) + def parse_block(self, block: FullBlock) -> Iterable[ParsedDocumentation]: + base_node = UnparsedDocumentationFile( + root_path=self.project.project_root, + path=block.file.path.relative_path, + original_file_path=block.path.original_file_path, + package_name=self.project.project_name, + # set contents to the actual internal contents of the block + file_contents=block.contents, + ) + try: + template = get_template(block.contents, {}) + except CompilationException as e: + e.node = base_node + raise + all_docs = list(self._parse_template_docs(template, base_node)) + if len(all_docs) != 1: + raise InternalException( + 'Got {} docs in an extracted docs block: block parser ' + 'mismatched with jinja'.format(len(all_docs)) + ) + return all_docs + + def parse_file(self, file_block: FileBlock): + searcher: Iterable[FullBlock] = BlockSearcher( + source=[file_block], + allowed_blocks={'docs'}, + source_tag_factory=FullBlock, + ) + for block in searcher: + for parsed in self.parse_block(block): + self.results.add_doc(file_block.file, parsed) + # mark the file as seen, even if there are no macros in it + self.results.get_file(file_block.file) diff --git a/core/dbt/parser/hooks.py b/core/dbt/parser/hooks.py index 6f21750208e..d2c4e3a6199 100644 --- a/core/dbt/parser/hooks.py +++ b/core/dbt/parser/hooks.py @@ -1,81 +1,116 @@ -import collections -from typing import Dict, Any - -import dbt.flags -import dbt.contracts.project -import dbt.utils +import os +from dataclasses import dataclass +from typing import Iterable, Iterator, Union, List, Tuple +from dbt.contracts.graph.manifest import FilePath from dbt.contracts.graph.parsed import ParsedHookNode -from dbt.contracts.graph.unparsed import UnparsedRunHook -from dbt.parser.base_sql import BaseSqlParser +from dbt.exceptions import InternalException from dbt.node_types import NodeType, RunHookType +from dbt.source_config import SourceConfig +from dbt.parser.base import SimpleParser +from dbt.parser.search import FileBlock +from dbt.utils import get_pseudo_hook_path + + +@dataclass +class HookBlock(FileBlock): + project: str + value: str + index: int + hook_type: RunHookType + + @property + def contents(self): + return self.value + + @property + def name(self): + return '{}-{!s}-{!s}'.format(self.project, self.hook_type, self.index) + + +class HookSearcher(Iterable[HookBlock]): + def __init__(self, project, source_file, hook_type) -> None: + self.project = project + self.source_file = source_file + self.hook_type = hook_type + + def _hook_list( + self, hooks: Union[str, List[str], Tuple[str, ...]] + ) -> List[str]: + if isinstance(hooks, tuple): + hooks = list(hooks) + elif not 
isinstance(hooks, list): + hooks = [hooks] + return hooks - -class HookParser(BaseSqlParser): - UnparsedNodeType = UnparsedRunHook - - @classmethod - def get_hooks_from_project(cls, config, hook_type): - if hook_type == RunHookType.Start: - hooks = config.on_run_start - elif hook_type == RunHookType.End: - hooks = config.on_run_end + def get_hook_defs(self) -> List[str]: + if self.hook_type == RunHookType.Start: + hooks = self.project.on_run_start + elif self.hook_type == RunHookType.End: + hooks = self.project.on_run_end else: - dbt.exceptions.InternalException( - 'hook_type must be one of "{}" or "{}"' - .format(RunHookType.Start, RunHookType.End)) + raise InternalException( + 'hook_type must be one of "{}" or "{}" (got {})' + .format(RunHookType.Start, RunHookType.End, self.hook_type) + ) + return self._hook_list(hooks) + + def __iter__(self) -> Iterator[HookBlock]: + hooks = self.get_hook_defs() + for index, hook in enumerate(hooks): + yield HookBlock( + file=self.source_file, + project=self.project.project_name, + value=hook, + index=index, + hook_type=self.hook_type, + ) - if type(hooks) not in (list, tuple): - hooks = [hooks] - return hooks +class HookParser(SimpleParser[HookBlock, ParsedHookNode]): + def transform(self, node): + return node - def get_hooks(self, hook_type): - project_hooks = collections.defaultdict(list) - - for project_name, project in self.all_projects.items(): - hooks = self.get_hooks_from_project(project, hook_type) - project_hooks[project_name].extend(hooks) - - return project_hooks - - def load_and_parse_run_hook_type(self, hook_type): - project_hooks = self.get_hooks(hook_type) - - result = [] - for project_name, hooks in project_hooks.items(): - for i, hook in enumerate(hooks): - hook_name = '{}-{}-{}'.format(project_name, hook_type, i) - hook_path = dbt.utils.get_pseudo_hook_path(hook_name) - - result.append({ - 'name': hook_name, - 'root_path': "{}/dbt_project.yml".format(project_name), - 'resource_type': NodeType.Operation, - 'path': hook_path, - 'original_file_path': hook_path, - 'package_name': project_name, - 'raw_sql': hook, - 'index': i - }) - - # hook_type is a RunHookType member, which "is a string", but it's also - # an enum, so hologram gets mad about that before even looking at if - # it's a string - bypass it by explicitly calling str(). - tags = [str(hook_type)] - results = self.parse_sql_nodes(result, tags=tags) - return results.parsed - - def load_and_parse(self): - hook_nodes = {} - for hook_type in RunHookType: - project_hooks = self.load_and_parse_run_hook_type( - hook_type, - ) - hook_nodes.update(project_hooks) + def get_paths(self): + searched_path = '.' 
+ relative_path = 'dbt_project.yml' + absolute_path = os.path.normcase(os.path.abspath(os.path.join( + self.project.project_root, searched_path, relative_path + ))) + path = FilePath( + searched_path='.', + relative_path='relative_path', + absolute_path=absolute_path, + ) + return [path] - return hook_nodes + def parse_from_dict(self, dct, validate=True) -> ParsedHookNode: + return ParsedHookNode.from_dict(dct, validate=validate) - def parse_from_dict(self, parsed_dict: Dict[str, Any]) -> ParsedHookNode: - """Given a dictionary, return the parsed entity for this parser""" - return ParsedHookNode.from_dict(parsed_dict) + @classmethod + def get_compiled_path(cls, block: HookBlock): + return get_pseudo_hook_path(block.name) + + def _create_parsetime_node( + self, + block: HookBlock, + path: str, + config: SourceConfig, + name=None, + **kwargs, + ) -> ParsedHookNode: + + return super()._create_parsetime_node( + block=block, path=path, config=config, + index=block.index, name=name, + tags=[str(block.hook_type)] + ) + + @property + def resource_type(self) -> NodeType: + return NodeType.Operation + + def parse_file(self, block: FileBlock) -> None: + for hook_type in RunHookType: + for hook in HookSearcher(self.project, block.file, hook_type): + self.parse_node(hook) diff --git a/core/dbt/parser/macros.py b/core/dbt/parser/macros.py index ab20d23ff18..b80505688b3 100644 --- a/core/dbt/parser/macros.py +++ b/core/dbt/parser/macros.py @@ -1,109 +1,85 @@ -import os -from typing import Dict, Any +from typing import Iterable -import jinja2.runtime +import jinja2 -import dbt.exceptions -import dbt.flags -import dbt.utils - -import dbt.clients.jinja -import dbt.clients.system -import dbt.contracts.project - -from dbt.parser.base import BaseParser -from dbt.node_types import NodeType -from dbt.logger import GLOBAL_LOGGER as logger +from dbt.clients import jinja from dbt.contracts.graph.unparsed import UnparsedMacro from dbt.contracts.graph.parsed import ParsedMacro +from dbt.exceptions import CompilationException +from dbt.logger import GLOBAL_LOGGER as logger +from dbt.node_types import NodeType, MacroType +from dbt.parser.base import BaseParser +from dbt.parser.search import FileBlock, FilesystemSearcher +from dbt.utils import MACRO_PREFIX -class MacroParser(BaseParser): - def parse_macro_file(self, macro_file_path, macro_file_contents, root_path, - package_name, resource_type, tags=None, context=None): - - logger.debug("Parsing {}".format(macro_file_path)) - - to_return = {} - - if tags is None: - tags = [] +class MacroParser(BaseParser[ParsedMacro]): + def get_paths(self): + return FilesystemSearcher( + project=self.project, + relative_dirs=self.project.macro_paths, + extension='.sql', + ) - base_node = UnparsedMacro( - path=macro_file_path, - original_file_path=macro_file_path, - package_name=package_name, - raw_sql=macro_file_contents, - root_path=root_path, - resource_type=resource_type, + @property + def resource_type(self) -> NodeType: + return NodeType.Macro + + @classmethod + def get_compiled_path(cls, block: FileBlock): + return block.path.relative_path + + def parse_macro(self, base_node: UnparsedMacro, name: str) -> ParsedMacro: + unique_id = self.generate_unique_id(name) + + return ParsedMacro( + path=base_node.path, + original_file_path=base_node.original_file_path, + package_name=base_node.package_name, + raw_sql=base_node.raw_sql, + root_path=base_node.root_path, + resource_type=base_node.resource_type, + name=name, + unique_id=unique_id, ) + def parse_unparsed_macros( + self, base_node: 
UnparsedMacro + ) -> Iterable[ParsedMacro]: try: - ast = dbt.clients.jinja.parse(macro_file_contents) - except dbt.exceptions.CompilationException as e: + ast = jinja.parse(base_node.raw_sql) + except CompilationException as e: e.node = base_node raise e for macro_node in ast.find_all(jinja2.nodes.Macro): macro_name = macro_node.name - node_type = None - if macro_name.startswith(dbt.utils.MACRO_PREFIX): - node_type = NodeType.Macro - name = macro_name.replace(dbt.utils.MACRO_PREFIX, '') - - if node_type != resource_type: + if not macro_name.startswith(MACRO_PREFIX): continue - unique_id = self.get_path(resource_type, package_name, name) - - merged = dbt.utils.deep_merge( - base_node.to_dict(), - { - 'name': name, - 'unique_id': unique_id, - 'tags': tags, - 'depends_on': {'macros': []}, - }) - - new_node = ParsedMacro.from_dict(merged) - - to_return[unique_id] = new_node - - return to_return - - def parse_from_dict(self, parsed_dict: Dict[str, Any]) -> ParsedMacro: - return ParsedMacro.from_dict(parsed_dict) + name = macro_name.replace(MACRO_PREFIX, '') + node = self.parse_macro(base_node, name) + yield node - def load_and_parse(self, package_name, root_dir, relative_dirs, - resource_type, tags=None): - extension = "[!.#~]*.sql" + def parse_file(self, block: FileBlock): + # mark the file as seen, even if there are no macros in it + self.results.get_file(block.file) + source_file = block.file - if tags is None: - tags = [] + original_file_path = source_file.path.original_file_path - file_matches = dbt.clients.system.find_matching( - root_dir, - relative_dirs, - extension) + logger.debug("Parsing {}".format(original_file_path)) - result = {} - - for file_match in file_matches: - file_contents = dbt.clients.system.load_file_contents( - file_match.get('absolute_path')) - - original_file_path = os.path.join( - file_match.get('searched_path'), - file_match.get('relative_path') - ) - - result.update( - self.parse_macro_file( - original_file_path, - file_contents, - root_dir, - package_name, - resource_type)) + # this is really only used for error messages + base_node = UnparsedMacro( + path=original_file_path, + original_file_path=original_file_path, + package_name=self.project.project_name, + raw_sql=source_file.contents, + root_path=self.project.project_root, + resource_type=MacroType(NodeType.Macro), + ) - return result + for node in self.parse_unparsed_macros(base_node): + self.results.add_macro(block.file, node) diff --git a/core/dbt/parser/models.py b/core/dbt/parser/models.py index 1f8f4e6a2a1..339004d267e 100644 --- a/core/dbt/parser/models.py +++ b/core/dbt/parser/models.py @@ -1,14 +1,22 @@ -from typing import Dict, Any - from dbt.contracts.graph.parsed import ParsedModelNode -from dbt.parser.base_sql import BaseSqlParser +from dbt.node_types import NodeType +from dbt.parser.base import SimpleSQLParser +from dbt.parser.search import FilesystemSearcher, FileBlock -class ModelParser(BaseSqlParser): - @classmethod - def get_compiled_path(cls, name, relative_path): - return relative_path +class ModelParser(SimpleSQLParser[ParsedModelNode]): + def get_paths(self): + return FilesystemSearcher( + self.project, self.project.source_paths, '.sql' + ) + + def parse_from_dict(self, dct, validate=True) -> ParsedModelNode: + return ParsedModelNode.from_dict(dct, validate=validate) - def parse_from_dict(self, parsed_dict: Dict[str, Any]) -> ParsedModelNode: - """Given a dictionary, return the parsed entity for this parser""" - return ParsedModelNode.from_dict(parsed_dict) + @property + def 
resource_type(self) -> NodeType: + return NodeType.Model + + @classmethod + def get_compiled_path(cls, block: FileBlock): + return block.path.relative_path diff --git a/core/dbt/parser/results.py b/core/dbt/parser/results.py new file mode 100644 index 00000000000..437c0f26a83 --- /dev/null +++ b/core/dbt/parser/results.py @@ -0,0 +1,197 @@ +from dataclasses import dataclass, field +from typing import TypeVar, MutableMapping, Mapping, Union, List + +from hologram import JsonSchemaMixin + +from dbt.contracts.graph.manifest import SourceFile, RemoteFile, FileHash +from dbt.contracts.graph.parsed import ( + ParsedNode, HasUniqueID, ParsedMacro, ParsedDocumentation, ParsedNodePatch, + ParsedSourceDefinition, ParsedAnalysisNode, ParsedHookNode, ParsedRPCNode, + ParsedModelNode, ParsedSeedNode, ParsedTestNode, ParsedSnapshotNode, +) +from dbt.contracts.util import Writable +from dbt.exceptions import ( + raise_duplicate_resource_name, raise_duplicate_patch_name, + CompilationException, InternalException +) + + +# Parsers can return anything as long as it's a unique ID +ParsedValueType = TypeVar('ParsedValueType', bound=HasUniqueID) + + +def _check_duplicates( + value: HasUniqueID, src: Mapping[str, HasUniqueID] +): + if value.unique_id in src: + raise_duplicate_resource_name(value, src[value.unique_id]) + + +ManifestNodes = Union[ + ParsedAnalysisNode, + ParsedHookNode, + ParsedModelNode, + ParsedSeedNode, + ParsedTestNode, + ParsedSnapshotNode, + ParsedRPCNode, +] + + +def dict_field(): + return field(default_factory=dict) + + +@dataclass +class ParseResult(JsonSchemaMixin, Writable): + vars_hash: FileHash + profile_hash: FileHash + project_hashes: MutableMapping[str, FileHash] + nodes: MutableMapping[str, ManifestNodes] = dict_field() + sources: MutableMapping[str, ParsedSourceDefinition] = dict_field() + docs: MutableMapping[str, ParsedDocumentation] = dict_field() + macros: MutableMapping[str, ParsedMacro] = dict_field() + patches: MutableMapping[str, ParsedNodePatch] = dict_field() + files: MutableMapping[str, SourceFile] = dict_field() + disabled: MutableMapping[str, List[ParsedNode]] = dict_field() + + def get_file(self, source_file: SourceFile) -> SourceFile: + key = source_file.search_key + if key is None: + return source_file + if key not in self.files: + self.files[key] = source_file + return self.files[key] + + def add_source( + self, source_file: SourceFile, node: ParsedSourceDefinition + ): + # nodes can't be overwritten! + _check_duplicates(node, self.sources) + self.sources[node.unique_id] = node + self.get_file(source_file).sources.append(node.unique_id) + + def add_node(self, source_file: SourceFile, node: ManifestNodes): + # nodes can't be overwritten! + _check_duplicates(node, self.nodes) + self.nodes[node.unique_id] = node + self.get_file(source_file).nodes.append(node.unique_id) + + def add_disabled(self, source_file: SourceFile, node: ParsedNode): + if node.unique_id in self.disabled: + self.disabled[node.unique_id].append(node) + else: + self.disabled[node.unique_id] = [node] + self.get_file(source_file).nodes.append(node.unique_id) + + def add_macro(self, source_file: SourceFile, macro: ParsedMacro): + # macros can be overwritten (should they be?) + self.macros[macro.unique_id] = macro + self.get_file(source_file).macros.append(macro.unique_id) + + def add_doc(self, source_file: SourceFile, doc: ParsedDocumentation): + # Docs also can be overwritten (should they be?) 
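ParseResult is the new single collection point for everything a parser produces: each add_* call both stores the parsed object and records its unique id against the SourceFile it came from, which is exactly the per-file bookkeeping that sanitized_update later relies on for cached reuse. A simplified, dependency-free sketch of that pattern (stand-in classes, not the dbt contracts):

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToySourceFile:
    search_key: str
    nodes: List[str] = field(default_factory=list)


@dataclass
class ToyParseResult:
    files: Dict[str, ToySourceFile] = field(default_factory=dict)
    nodes: Dict[str, str] = field(default_factory=dict)

    def get_file(self, source_file: ToySourceFile) -> ToySourceFile:
        # like ParseResult.get_file: register the file on first sight
        return self.files.setdefault(source_file.search_key, source_file)

    def add_node(self, source_file: ToySourceFile, unique_id: str, node: str) -> None:
        if unique_id in self.nodes:
            raise ValueError('duplicate resource name: ' + unique_id)
        self.nodes[unique_id] = node
        self.get_file(source_file).nodes.append(unique_id)


result = ToyParseResult()
f = ToySourceFile('models/stg_orders.sql')
result.add_node(f, 'model.my_project.stg_orders', '<parsed node>')
print(result.files['models/stg_orders.sql'].nodes)
# ['model.my_project.stg_orders']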
+ self.docs[doc.unique_id] = doc + self.get_file(source_file).docs.append(doc.unique_id) + + def add_patch(self, source_file: SourceFile, patch: ParsedNodePatch): + # matches can't be overwritten + if patch.name in self.patches: + raise_duplicate_patch_name(patch.name, patch, + self.patches[patch.name]) + self.patches[patch.name] = patch + self.get_file(source_file).patches.append(patch.name) + + def _get_disabled( + self, unique_id: str, match_file: SourceFile + ) -> List[ParsedNode]: + if unique_id not in self.disabled: + raise InternalException( + 'called _get_disabled with id={}, but it does not exist' + .format(unique_id) + ) + return [ + n for n in self.disabled[unique_id] + if n.original_file_path == match_file.path.original_file_path + ] + + def sanitized_update( + self, source_file: SourceFile, old_result: 'ParseResult', + ) -> bool: + """Perform a santized update. If the file can't be updated, invalidate + it and return false. + """ + if isinstance(source_file.path, RemoteFile): + return False + + old_file = old_result.get_file(source_file) + for doc_id in old_file.docs: + doc = _expect_value(doc_id, old_result.docs, old_file, "docs") + self.add_doc(source_file, doc) + + for macro_id in old_file.macros: + macro = _expect_value( + macro_id, old_result.macros, old_file, "macros" + ) + self.add_macro(source_file, macro) + + for source_id in old_file.sources: + source = _expect_value( + source_id, old_result.sources, old_file, "sources" + ) + self.add_source(source_file, source) + + # because we know this is how we _parsed_ the node, we can safely + # assume if it's disabled it was done by the project or file, and + # we can keep our old data + for node_id in old_file.nodes: + if node_id in old_result.nodes: + node = old_result.nodes[node_id] + self.add_node(source_file, node) + elif node_id in old_result.disabled: + matches = old_result._get_disabled(node_id, source_file) + for match in matches: + self.add_disabled(source_file, match) + else: + raise CompilationException( + 'Expected to find "{}" in cached "manifest.nodes" or ' + '"manifest.disabled" based on cached file information: {}!' + .format(node_id, old_file) + ) + + for name in old_file.patches: + patch = _expect_value( + name, old_result.patches, old_file, "patches" + ) + self.add_patch(source_file, patch) + + return True + + def has_file(self, source_file: SourceFile) -> bool: + key = source_file.search_key + if key is None: + return False + if key not in self.files: + return False + my_checksum = self.files[key].checksum + return my_checksum == source_file.checksum + + @classmethod + def rpc(cls): + # ugh! + return cls(FileHash.empty(), FileHash.empty(), {}) + + +T = TypeVar('T') + + +def _expect_value( + key: str, src: Mapping[str, T], old_file: SourceFile, name: str +) -> T: + if key not in src: + raise CompilationException( + 'Expected to find "{}" in cached "result.{}" based ' + 'on cached file information: {}!' 
+ .format(key, name, old_file) + ) + return src[key] diff --git a/core/dbt/parser/rpc.py b/core/dbt/parser/rpc.py new file mode 100644 index 00000000000..3b3e6a4833d --- /dev/null +++ b/core/dbt/parser/rpc.py @@ -0,0 +1,62 @@ +import os +from dataclasses import dataclass +from typing import Iterable + +from dbt.contracts.graph.manifest import SourceFile +from dbt.contracts.graph.parsed import ParsedRPCNode, ParsedMacro +from dbt.contracts.graph.unparsed import UnparsedMacro +from dbt.exceptions import InternalException +from dbt.node_types import NodeType, MacroType +from dbt.parser.base import SimpleSQLParser +from dbt.parser.macros import MacroParser +from dbt.parser.search import FileBlock + + +@dataclass +class RPCBlock(FileBlock): + rpc_name: str + + @property + def name(self): + return self.rpc_name + + +class RPCCallParser(SimpleSQLParser[ParsedRPCNode]): + def get_paths(self): + return [] + + def parse_from_dict(self, dct, validate=True) -> ParsedRPCNode: + return ParsedRPCNode.from_dict(dct, validate=validate) + + @property + def resource_type(self) -> NodeType: + return NodeType.RPCCall + + def get_compiled_path(cls, block: FileBlock): + # we do it this way to make mypy happy + if not isinstance(block, RPCBlock): + raise InternalException( + 'While parsing RPC calls, got an actual file block instead of ' + 'an RPC block: {}'.format(block) + ) + + return os.path.join('rpc', block.name) + + def parse_remote(self, sql: str, name: str) -> ParsedRPCNode: + source_file = SourceFile.remote(contents=sql) + contents = RPCBlock(rpc_name=name, file=source_file) + return self.parse_node(contents) + + +class RPCMacroParser(MacroParser): + def parse_remote(self, contents) -> Iterable[ParsedMacro]: + base = UnparsedMacro( + path='from remote system', + original_file_path='from remote system', + package_name=self.project.project_name, + raw_sql=contents, + root_path=self.project.project_root, + resource_type=MacroType(NodeType.Macro), + ) + for node in self.parse_unparsed_macros(base): + yield node diff --git a/core/dbt/parser/schema_test_builders.py b/core/dbt/parser/schema_test_builders.py new file mode 100644 index 00000000000..84ee7e22583 --- /dev/null +++ b/core/dbt/parser/schema_test_builders.py @@ -0,0 +1,299 @@ +import hashlib +import re +from dataclasses import dataclass +from typing import Generic, TypeVar, Dict, Any, Tuple, Optional, List, Union + +from dbt.clients.jinja import get_rendered +from dbt.contracts.graph.unparsed import ( + UnparsedNodeUpdate, UnparsedSourceDefinition, + UnparsedSourceTableDefinition, NamedTested +) +from dbt.exceptions import raise_compiler_error +from dbt.parser.search import FileBlock + + +def get_nice_schema_test_name( + test_type: str, test_name: str, args: Dict[str, Any] +) -> Tuple[str, str]: + flat_args = [] + for arg_name in sorted(args): + arg_val = args[arg_name] + + if isinstance(arg_val, dict): + parts = list(arg_val.values()) + elif isinstance(arg_val, (list, tuple)): + parts = list(arg_val) + else: + parts = [arg_val] + + flat_args.extend([str(part) for part in parts]) + + clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg) for arg in flat_args] + unique = "__".join(clean_flat_args) + + cutoff = 32 + if len(unique) <= cutoff: + label = unique + else: + label = hashlib.md5(unique.encode('utf-8')).hexdigest() + + filename = '{}_{}_{}'.format(test_type, test_name, label) + name = '{}_{}_{}'.format(test_type, test_name, unique) + + return filename, name + + +def as_kwarg(key: str, value: Any) -> str: + test_value = str(value) + is_function = 
re.match(r'^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$', + test_value) + + # if the value is a function, don't wrap it in quotes! + if is_function: + formatted_value = value + else: + formatted_value = value.__repr__() + + return "{key}={value}".format(key=key, value=formatted_value) + + +@dataclass +class YamlBlock(FileBlock): + data: Dict[str, Any] + + @classmethod + def from_file_block(cls, src: FileBlock, data: Dict[str, Any]): + return cls( + file=src.file, + data=data, + ) + + +@dataclass +class SourceTarget: + source: UnparsedSourceDefinition + table: UnparsedSourceTableDefinition + + @property + def name(self) -> str: + return '{0.name}_{1.name}'.format(self.source, self.table) + + @property + def columns(self) -> List[NamedTested]: + if self.table.columns is None: + return [] + else: + return self.table.columns + + @property + def tests(self) -> List[Union[Dict[str, Any], str]]: + if self.table.tests is None: + return [] + else: + return self.table.tests + + +ModelTarget = UnparsedNodeUpdate + + +Target = TypeVar('Target', ModelTarget, SourceTarget) + + +@dataclass +class TargetBlock(YamlBlock, Generic[Target]): + target: Target + + @property + def name(self): + return self.target.name + + @property + def columns(self): + if self.target.columns is None: + return [] + else: + return self.target.columns + + @property + def tests(self) -> List[Union[Dict[str, Any], str]]: + if self.target.tests is None: + return [] + else: + return self.target.tests + + @classmethod + def from_yaml_block( + cls, src: YamlBlock, target: Target + ) -> 'TargetBlock[Target]': + return cls( + file=src.file, + data=src.data, + target=target, + ) + + +@dataclass +class SchemaTestBlock(TargetBlock): + test: Dict[str, Any] + column_name: Optional[str] + + @classmethod + def from_target_block( + cls, src: TargetBlock, test: Dict[str, Any], column_name: Optional[str] + ) -> 'SchemaTestBlock': + return cls( + file=src.file, + data=src.data, + target=src.target, + test=test, + column_name=column_name + ) + + +class TestBuilder(Generic[Target]): + """An object to hold assorted test settings and perform basic parsing + + Test names have the following pattern: + - the test name itself may be namespaced (package.test) + - or it may not be namespaced (test) + - the test may have arguments embedded in the name (, severity=WARN) + - or it may not have arguments. + + """ + TEST_NAME_PATTERN = re.compile( + r'((?P([a-zA-Z_][0-9a-zA-Z_]*))\.)?' 
+ r'(?P([a-zA-Z_][0-9a-zA-Z_]*))' + ) + # map magic keys to default values + MODIFIER_ARGS = {'severity': 'ERROR'} + + def __init__( + self, + test: Dict[str, Any], + target: Target, + package_name: str, + render_ctx: Dict[str, Any], + column_name: str = None, + ) -> None: + test_name, test_args = self.extract_test_args(test, column_name) + self.args: Dict[str, Any] = test_args + self.package_name: str = package_name + self.target: Target = target + + match = self.TEST_NAME_PATTERN.match(test_name) + if match is None: + raise_compiler_error( + 'Test name string did not match expected pattern: {}' + .format(test_name) + ) + + groups = match.groupdict() + self.name: str = groups['test_name'] + self.namespace: str = groups['test_namespace'] + self.modifiers: Dict[str, Any] = {} + for key, default in self.MODIFIER_ARGS.items(): + value = self.args.pop(key, default) + if isinstance(value, str): + value = get_rendered(value, render_ctx) + self.modifiers[key] = value + + if self.namespace is not None: + self.package_name = self.namespace + + compiled_name, fqn_name = self.get_test_name() + self.compiled_name: str = compiled_name + self.fqn_name: str = fqn_name + + def _bad_type(self) -> TypeError: + return TypeError('invalid target type "{}"'.format(type(self.target))) + + @staticmethod + def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]: + if not isinstance(test, dict): + raise_compiler_error( + 'test must be dict or str, got {} (value {})'.format( + type(test), test + ) + ) + + test = list(test.items()) + if len(test) != 1: + raise_compiler_error( + 'test definition dictionary must have exactly one key, got' + ' {} instead ({} keys)'.format(test, len(test)) + ) + test_name, test_args = test[0] + + if not isinstance(test_args, dict): + raise_compiler_error( + 'test arguments must be dict, got {} (value {})'.format( + type(test_args), test_args + ) + ) + if not isinstance(test_name, str): + raise_compiler_error( + 'test name must be a str, got {} (value {})'.format( + type(test_name), test_name + ) + ) + if name is not None: + test_args['column_name'] = name + return test_name, test_args + + def severity(self) -> str: + return self.modifiers.get('severity', 'ERROR').upper() + + def test_kwargs_str(self) -> str: + # sort the dict so the keys are rendered deterministically (for tests) + return ', '.join(( + as_kwarg(key, self.args[key]) + for key in sorted(self.args) + )) + + def macro_name(self) -> str: + macro_name = 'test_{}'.format(self.name) + if self.namespace is not None: + macro_name = "{}.{}".format(self.namespace, macro_name) + return macro_name + + def describe_test_target(self) -> str: + if isinstance(self.target, ModelTarget): + fmt = "model('{0}')" + elif isinstance(self.target, SourceTarget): + fmt = "source('{0.source}', '{0.table}')" + else: + raise self._bad_type() + return fmt.format(self.target) + + raise NotImplementedError('describe_test_target not implemented!') + + def get_test_name(self) -> Tuple[str, str]: + if isinstance(self.target, ModelTarget): + name = self.name + elif isinstance(self.target, SourceTarget): + name = 'source_' + self.name + else: + raise self._bad_type() + if self.namespace is not None: + name = '{}_{}'.format(self.namespace, name) + return get_nice_schema_test_name(name, self.target.name, self.args) + + def build_raw_sql(self) -> str: + return ( + "{{{{ config(severity='{severity}') }}}}" + "{{{{ {macro}(model={model}, {kwargs}) }}}}" + ).format( + model=self.build_model_str(), + macro=self.macro_name(), + 
kwargs=self.test_kwargs_str(), + severity=self.severity() + ) + + def build_model_str(self): + if isinstance(self.target, ModelTarget): + fmt = "ref('{0.name}')" + elif isinstance(self.target, SourceTarget): + fmt = "source('{0.source.name}', '{0.table.name}')" + else: + raise self._bad_type() + return fmt.format(self.target) diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index ddf9f49b226..33df4516cdd 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -1,250 +1,55 @@ -from __future__ import unicode_literals -import itertools import os -import re -import hashlib -from typing import Optional +from typing import Iterable, Dict, Any, Union, List, Optional from hologram import ValidationError -import dbt.exceptions -import dbt.flags -import dbt.utils - -import dbt.clients.yaml_helper -import dbt.context.parser -import dbt.contracts.project - from dbt.context.common import generate_config_context + from dbt.clients.jinja import get_rendered -from dbt.node_types import NodeType +from dbt.clients.yaml_helper import load_yaml_text +from dbt.config.renderer import ConfigRenderer +from dbt.contracts.graph.manifest import SourceFile +from dbt.contracts.graph.parsed import ( + ParsedNodePatch, ParsedTestNode, ParsedSourceDefinition, ColumnInfo, Docref +) +from dbt.contracts.graph.unparsed import ( + UnparsedSourceDefinition, UnparsedNodeUpdate, NamedTested, + UnparsedSourceTableDefinition, FreshnessThreshold +) +from dbt.context.parser import docs +from dbt.exceptions import ( + warn_or_error, validator_error_message, JSONValidationException, + raise_invalid_schema_yml_version, ValidationException, CompilationException +) from dbt.logger import GLOBAL_LOGGER as logger +from dbt.node_types import NodeType, SourceType +from dbt.parser.base import SimpleParser +from dbt.parser.search import FileBlock, FilesystemSearcher +from dbt.parser.schema_test_builders import ( + TestBuilder, SourceTarget, ModelTarget, Target, + SchemaTestBlock, TargetBlock, YamlBlock, +) from dbt.utils import get_pseudo_test_path -from dbt.contracts.graph.unparsed import UnparsedNode, UnparsedNodeUpdate, \ - UnparsedSourceDefinition, UnparsedSourceTableDefinition, FreshnessThreshold -from dbt.contracts.graph.parsed import ParsedNodePatch, ParsedTestNode, \ - ParsedSourceDefinition, ParsedNode, ColumnInfo, Docref -from dbt.parser.base import MacrosKnownParser -from dbt.config.renderer import ConfigRenderer -from dbt.exceptions import JSONValidationException, validator_error_message - -from typing import Dict, List - - -def get_nice_schema_test_name(test_type, test_name, args): - flat_args = [] - for arg_name in sorted(args): - arg_val = args[arg_name] - - if isinstance(arg_val, dict): - parts = arg_val.values() - elif isinstance(arg_val, (list, tuple)): - parts = arg_val - else: - parts = [arg_val] - - flat_args.extend([str(part) for part in parts]) - - clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg) for arg in flat_args] - unique = "__".join(clean_flat_args) - - cutoff = 32 - if len(unique) <= cutoff: - label = unique - else: - label = hashlib.md5(unique.encode('utf-8')).hexdigest() - - filename = '{}_{}_{}'.format(test_type, test_name, label) - name = '{}_{}_{}'.format(test_type, test_name, unique) - - return filename, name - - -def as_kwarg(key, value): - test_value = str(value) - is_function = re.match(r'^\s*(env_var|ref|var|source|doc)\s*\(.+\)\s*$', - test_value) - - # if the value is a function, don't wrap it in quotes! 
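As a worked example of what TestBuilder.build_raw_sql produces, here is the string a bare `unique` test on a model's id column would render to; the template literal is copied from build_raw_sql above, and the other values are filled in by hand for a no-namespace, default-severity test:

# a plain `unique` test on orders.id
severity = 'ERROR'
macro = 'test_unique'                 # macro_name() for a test called 'unique'
model = "ref('orders')"               # build_model_str() for a model target
kwargs = "column_name='id'"           # as_kwarg() repr()s plain string values

raw_sql = (
    "{{{{ config(severity='{severity}') }}}}"
    "{{{{ {macro}(model={model}, {kwargs}) }}}}"
).format(model=model, macro=macro, kwargs=kwargs, severity=severity)

print(raw_sql)
# {{ config(severity='ERROR') }}{{ test_unique(model=ref('orders'), column_name='id') }}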
- if is_function: - formatted_value = value - else: - formatted_value = value.__repr__() - - return "{key}={value}".format(key=key, value=formatted_value) - - -class TestBuilder: - """An object to hold assorted test settings and perform basic parsing - - Test names have the following pattern: - - the test name itself may be namespaced (package.test) - - or it may not be namespaced (test) - - the test may have arguments embedded in the name (, severity=WARN) - - or it may not have arguments. - - """ - TEST_NAME_PATTERN = re.compile( - r'((?P([a-zA-Z_][0-9a-zA-Z_]*))\.)?' - r'(?P([a-zA-Z_][0-9a-zA-Z_]*))' - ) - # map magic keys to default values - MODIFIER_ARGS = {'severity': 'ERROR'} - - def __init__(self, test, target, column_name, package_name, render_ctx): - test_name, test_args = self.extract_test_args(test, column_name) - self.args = test_args - self.package_name = package_name - self.target = target - - match = self.TEST_NAME_PATTERN.match(test_name) - if match is None: - dbt.exceptions.raise_compiler_error( - 'Test name string did not match expected pattern: {}' - .format(test_name) - ) - - groups = match.groupdict() - self.name = groups['test_name'] - self.namespace = groups['test_namespace'] - self.modifiers = {} - for key, default in self.MODIFIER_ARGS.items(): - value = self.args.pop(key, default) - if isinstance(value, str): - value = get_rendered(value, render_ctx) - self.modifiers[key] = value - - if self.namespace is not None: - self.package_name = self.namespace - - @staticmethod - def extract_test_args(test, name=None): - if not isinstance(test, dict): - dbt.exceptions.raise_compiler_error( - 'test must be dict or str, got {} (value {})'.format( - type(test), test - ) - ) - - test = list(test.items()) - if len(test) != 1: - dbt.exceptions.raise_compiler_error( - 'test definition dictionary must have exactly one key, got' - ' {} instead ({} keys)'.format(test, len(test)) - ) - test_name, test_args = test[0] - - if not isinstance(test_args, dict): - dbt.exceptions.raise_compiler_error( - 'test arguments must be dict, got {} (value {})'.format( - type(test_args), test_args - ) - ) - if not isinstance(test_name, str): - dbt.exceptions.raise_compiler_error( - 'test name must be a str, got {} (value {})'.format( - type(test_name), test_name - ) - ) - if name is not None: - test_args['column_name'] = name - return test_name, test_args - - def severity(self): - return self.modifiers.get('severity', 'ERROR').upper() - - def test_kwargs_str(self): - # sort the dict so the keys are rendered deterministically (for tests) - return ', '.join(( - as_kwarg(key, self.args[key]) - for key in sorted(self.args) - )) - - def macro_name(self): - macro_name = 'test_{}'.format(self.name) - if self.namespace is not None: - macro_name = "{}.{}".format(self.namespace, macro_name) - return macro_name - - def build_model_str(self): - raise NotImplementedError('build_model_str not implemented!') - - def get_test_name(self): - raise NotImplementedError('get_test_name not implemented!') - - def build_raw_sql(self): - return ( - "{{{{ config(severity='{severity}') }}}}" - "{{{{ {macro}(model={model}, {kwargs}) }}}}" - ).format( - model=self.build_model_str(), - macro=self.macro_name(), - kwargs=self.test_kwargs_str(), - severity=self.severity() - ) - - -class RefTestBuilder(TestBuilder): - def build_model_str(self): - return "ref('{}')".format(self.target.name) - - def get_test_name(self): - return get_nice_schema_test_name(self.name, - self.target.name, - self.args) - - def describe_test_target(self): - 
return 'model "{}"'.format(self.target) - -class SourceTestBuilder(TestBuilder): - def build_model_str(self): - return "source('{}', '{}')".format( - self.target['source'].name, - self.target['table'].name - ) - def get_test_name(self): - target_name = '{}_{}'.format(self.target['source'].name, - self.target['table'].name) - return get_nice_schema_test_name( - 'source_' + self.name, - target_name, - self.args - ) +UnparsedSchemaYaml = Union[UnparsedSourceDefinition, UnparsedNodeUpdate] - def describe_test_target(self): - return 'source "{0[source]}.{0[table]}"'.format(self.target) +TestDef = Union[str, Dict[str, Any]] def warn_invalid(filepath, key, value, explain): msg = ( "Invalid test config given in {} @ {}: {} {}" ).format(filepath, key, value, explain) - dbt.exceptions.warn_or_error(msg, value, - log_fmt='Compilation warning: {}\n') + warn_or_error(msg, value, log_fmt='Compilation warning: {}\n') -def _filter_validate(filepath, location, values, validate): - """Generator for validate() results called against all given values. On - errors, fields are warned about and ignored, unless strict mode is set in - which case a compiler error is raised. - """ - for value in values: - if not isinstance(value, dict): - warn_invalid(filepath, location, value, '(expected a dict)') - continue - try: - yield validate(value) - # we don't want to fail the full run, but we do want to fail - # parsing this file - except ValidationError as exc: - msg = validator_error_message(exc) - warn_invalid(filepath, location, value, '- ' + msg) - continue - except JSONValidationException as exc: - warn_invalid(filepath, location, value, '- ' + exc.msg) - continue +def warn_validation_error(filepath, key, value, exc): + if isinstance(exc, ValidationError): + msg = validator_error_message(exc) + else: + msg = exc.msg + warn_invalid(filepath, key, value, '- ' + msg) class ParserRef: @@ -258,190 +63,225 @@ def add(self, column_name, description): description=description) -class SchemaBaseTestParser(MacrosKnownParser): - Builder = TestBuilder - - def _parse_column(self, target, column, package_name, root_dir, path, - refs): - # this should yield ParsedNodes where resource_type == NodeType.Test - column_name = column.name - description = column.description - - refs.add(column_name, description) - context = { - 'doc': dbt.context.parser.docs(target, refs.docrefs, column_name) - } +def collect_docrefs( + target: UnparsedSchemaYaml, + refs: ParserRef, + column_name: Optional[str], + *descriptions: str, +) -> None: + context = {'doc': docs(target, refs.docrefs, column_name)} + for description in descriptions: get_rendered(description, context) - for test in column.tests: - try: - yield self.build_test_node( - target, package_name, test, root_dir, - path, column_name - ) - except dbt.exceptions.CompilationException as exc: - dbt.exceptions.warn_or_error( - 'Compilation warning: Invalid test config given in {}:' - '\n\t{}'.format(path, exc.msg), None - ) - continue - - def parse_from_dict(self, parsed_dict) -> ParsedTestNode: - return ParsedTestNode.from_dict(parsed_dict) - - def build_test_node(self, test_target, package_name, test, root_dir, path, - column_name=None): - """Build a test node against the given target (a model or a source). - - :param test_target: An unparsed form of the target. 
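The "potentially many Targets per yaml block" point in the SchemaParser docstring is easiest to see with sources: read_yaml_sources yields one SourceTarget per table of each source entry, and SourceTarget.name is the '<source>_<table>' concatenation. A toy sketch of that fan-out with stand-in classes (the names raw_jaffle, orders and customers are made up):

from dataclasses import dataclass
from typing import List


@dataclass
class ToySource:
    name: str
    tables: List[str]


@dataclass
class ToySourceTarget:
    source: ToySource
    table: str

    @property
    def name(self) -> str:
        # mirrors SourceTarget.name: '<source>_<table>'
        return '{}_{}'.format(self.source.name, self.table)


raw = ToySource(name='raw_jaffle', tables=['orders', 'customers'])
targets = [ToySourceTarget(raw, t) for t in raw.tables]
print([t.name for t in targets])
# ['raw_jaffle_orders', 'raw_jaffle_customers']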
- """ - if isinstance(test, str): - test = {test: {}} - - ctx = generate_config_context(self.root_project_config.cli_vars) - test_info = self.Builder(test, test_target, column_name, package_name, - ctx) +def _trimmed(inp: str) -> str: + if len(inp) < 50: + return inp + return inp[:44] + '...' + inp[-3:] - source_package = self.all_projects.get(test_info.package_name) - if source_package is None: - desc = '"{}" test on {}'.format( - test_info.name, test_info.describe_test_target() - ) - dbt.exceptions.raise_dep_not_found(None, desc, test_info.namespace) - - test_path = os.path.basename(path) - hashed_name, full_name = test_info.get_test_name() +class SchemaParser(SimpleParser[SchemaTestBlock, ParsedTestNode]): + """ + The schema parser is really big because schemas are really complicated! + + There are basically three phases to the schema parser: + - read_yaml_{models,sources}: read in yaml as a dictionary, then + validate it against the basic structures required so we can start + parsing (ModelTarget, SourceTarget) + - these return potentially many Targets per yaml block, since earch + source can have multiple tables + - parse_target_{model,source}: Read in the underlying target, parse and + return a list of all its tests (model and column tests), collect + any refs/descriptions, and return a parsed entity with the + appropriate information. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._renderer = ConfigRenderer(self.root_project.cli_vars) - hashed_path = get_pseudo_test_path(hashed_name, test_path, - 'schema_test') + @classmethod + def get_compiled_path(cls, block: FileBlock) -> str: + # should this raise an error? + return block.path.relative_path - full_path = get_pseudo_test_path(full_name, test_path, 'schema_test') - raw_sql = test_info.build_raw_sql() + @property + def resource_type(self) -> NodeType: + return NodeType.Test - unparsed = UnparsedNode( - name=full_name, - resource_type=NodeType.Test, - package_name=test_info.package_name, - root_path=root_dir, - path=hashed_path, - original_file_path=path, - raw_sql=raw_sql + def get_paths(self): + return FilesystemSearcher( + self.project, self.project.source_paths, '.yml' ) - # supply our own fqn which overrides the hashed version from the path - # TODO: is this necessary even a little bit for tests? - fqn_override = self.get_fqn(unparsed.replace(path=full_path), - source_package) - - node_path = self.get_path(NodeType.Test, unparsed.package_name, - unparsed.name) - - result = self.parse_node(unparsed, - node_path, - source_package, - tags=['schema'], - fqn_extra=None, - fqn=fqn_override, - column_name=column_name) - - parse_ok = self.check_block_parsing(full_name, test_path, raw_sql) - if not parse_ok: - # if we had a parse error in parse_node, we would not get here. So - # this means we rejected a good file :( - raise dbt.exceptions.InternalException( - 'the block parser rejected a good node: {} was marked invalid ' - 'but is actually valid!'.format(test_path) - ) - return result + def parse_from_dict(self, dct, validate=True) -> ParsedTestNode: + return ParsedTestNode.from_dict(dct, validate=validate) + def _parse_format_version( + self, yaml: YamlBlock + ) -> None: + path = yaml.path.relative_path + if 'version' not in yaml.data: + raise_invalid_schema_yml_version(path, 'no version is specified') -class SchemaModelParser(SchemaBaseTestParser): - Builder = RefTestBuilder + version = yaml.data['version'] + # if it's not an integer, the version is malformed, or not + # set. 
Either way, only 'version: 2' is supported. + if not isinstance(version, int): + raise_invalid_schema_yml_version( + path, 'the version is not an integer' + ) + if version != 2: + raise_invalid_schema_yml_version( + path, 'version {} is not supported'.format(version) + ) - def parse_models_entry(self, model, path, package_name, root_dir): - model_name = model.name - refs = ParserRef() - for column in model.columns: - column_tests = self._parse_column(model, column, package_name, - root_dir, path, refs) - for node in column_tests: - yield 'test', node + def _get_dicts_for( + self, yaml: YamlBlock, key: str + ) -> Iterable[Dict[str, Any]]: + data = yaml.data.get(key, []) + if not isinstance(data, list): + raise CompilationException( + '{} must be a list, got {} instead: ({})' + .format(key, type(data), _trimmed(str(data))) + ) + path = yaml.path.original_file_path - for test in model.tests: + for entry in data: + str_keys = ( + isinstance(entry, dict) and + all(isinstance(k, str) for k in entry) + ) + if str_keys: + yield entry + else: + warn_invalid(path, key, entry, '(expected a Dict[str])') + + def read_yaml_models( + self, yaml: YamlBlock + ) -> Iterable[ModelTarget]: + path = yaml.path.original_file_path + yaml_key = 'models' + + for data in self._get_dicts_for(yaml, yaml_key): try: - node = self.build_test_node(model, package_name, test, - root_dir, path) - except dbt.exceptions.CompilationException as exc: - dbt.exceptions.warn_or_error( - 'Compilation warning: Invalid test config given in {}:' - '\n\t{}'.format(path, exc.msg), None - ) - continue - yield 'test', node - - context = {'doc': dbt.context.parser.docs(model, refs.docrefs)} - description = model.description - get_rendered(description, context) + model = UnparsedNodeUpdate.from_dict(data) + # we don't want to fail the full run, but we do want to fail + # parsing this block + except (ValidationError, JSONValidationException) as exc: + warn_validation_error(path, yaml_key, data, exc) + else: + yield model + + def read_yaml_sources( + self, yaml: YamlBlock + ) -> Iterable[SourceTarget]: + path = yaml.path.original_file_path + yaml_key = 'sources' + + for data in self._get_dicts_for(yaml, yaml_key): + try: + data = self._renderer.render_schema_source(data) + source = UnparsedSourceDefinition.from_dict(data) + except (ValidationError, JSONValidationException) as exc: + warn_validation_error(path, yaml_key, data, exc) + else: + for table in source.tables: + yield SourceTarget(source, table) + + def _yaml_from_file( + self, source_file: SourceFile + ) -> Optional[Dict[str, Any]]: + """If loading the yaml fails, the file will be skipped with an INFO + message. TODO(jeb): should this be a warning? + """ + path: str = source_file.path.relative_path + try: + return load_yaml_text(source_file.contents) + except ValidationException as e: + logger.info("Error reading {}:{} - Skipping\n{}".format( + self.project.project_name, path, e)) + return None + + def parse_column( + self, block: TargetBlock, column: NamedTested, refs: ParserRef + ) -> None: + column_name = column.name + description = column.description + collect_docrefs(block.target, refs, column_name, description) - patch = ParsedNodePatch( - name=model_name, - original_file_path=path, - description=description, - columns=refs.column_info, - docrefs=refs.docrefs - ) - yield 'patch', patch + refs.add(column_name, description) - def parse_all(self, models, path, package_name, root_dir): - """Parse all the model dictionaries in models. 
+ if not column.tests: + return - :param List[dict] models: The `models` section of the schema.yml, as a - list of dicts. - :param str path: The path to the schema.yml file - :param str package_name: The name of the current package - :param str root_dir: The root directory of the search + for test in column.tests: + self.parse_test(block, test, column_name) + + def parse_node(self, block: SchemaTestBlock) -> ParsedTestNode: + """In schema parsing, we rewrite most of the part of parse_node that + builds the initial node to be parsed, but rendering is basically the + same """ - filtered = _filter_validate(path, 'models', models, - UnparsedNodeUpdate.from_dict) - nodes = itertools.chain.from_iterable( - self.parse_models_entry(model, path, package_name, root_dir) - for model in filtered + render_ctx = generate_config_context(self.root_project.cli_vars) + builder = TestBuilder[Target]( + test=block.test, + target=block.target, + column_name=block.column_name, + package_name=self.project.project_name, + render_ctx=render_ctx, ) - for node_type, node in nodes: - yield node_type, node - - -class SchemaSourceParser(SchemaBaseTestParser): - Builder = SourceTestBuilder - def __init__(self, root_project_config, all_projects, macro_manifest): - super().__init__( - root_project_config=root_project_config, - all_projects=all_projects, - macro_manifest=macro_manifest + original_name = os.path.basename(block.path.original_file_path) + compiled_path = get_pseudo_test_path( + builder.compiled_name, original_name, 'schema_test', ) - self._renderer = ConfigRenderer(self.root_project_config.cli_vars) - - def _build_raw_sql(self, test_info): - return test_info.build_source_test_raw_sql() - - def _generate_test_name(self, test_info): - target_name = '{}_{}'.format(test_info.target['source']['name'], - test_info.target['table']['name']) - return get_nice_schema_test_name( - 'source_' + test_info.name, - target_name, - test_info.args + fqn_path = get_pseudo_test_path( + builder.fqn_name, original_name, 'schema_test', ) + # the fqn for tests actually happens in the test target's name, which + # is not necessarily this package's name + fqn = self.get_fqn(fqn_path, builder.fqn_name) + + config = self.initial_config(fqn) + + node = self._create_parsetime_node( + block=block, + path=compiled_path, + config=config, + tags=['schema'], + name=builder.fqn_name, + raw_sql=builder.build_raw_sql(), + column_name=block.column_name, + ) + self.render_update(node, config) + self.add_result_node(block, node) + return node + + def parse_test( + self, + target_block: TargetBlock, + test: TestDef, + column_name: Optional[str] + ) -> None: - @staticmethod - def _describe_test_target(test_target): - return 'source "{0[source]}.{0[table]}"'.format(test_target) + if isinstance(test, str): + test = {test: {}} - def get_path(self, *parts): - return '.'.join(str(s) for s in parts) + block = SchemaTestBlock.from_target_block( + src=target_block, + test=test, + column_name=column_name + ) + try: + self.parse_node(block) + except CompilationException as exc: + context = _trimmed(str(block.target)) + msg = ( + 'Compilation warning: Invalid test config given in {}:' + '\n\t{}\n\t@: {}' + .format(block.path.original_file_path, exc.msg, context) + ) + warn_or_error(msg, None) def _calculate_freshness( self, @@ -459,29 +299,40 @@ def _calculate_freshness( else: return None - def generate_source_node(self, source, table, path, package_name, root_dir, - refs): - unique_id = self.get_path(NodeType.Source, package_name, - source.name, table.name) - - 
context = {'doc': dbt.context.parser.docs(source, refs.docrefs)} + def parse_tests(self, block: TargetBlock) -> ParserRef: + refs = ParserRef() + for column in block.columns: + self.parse_column(block, column, refs) + + for test in block.tests: + self.parse_test(block, test, None) + return refs + + def generate_source_node( + self, block: TargetBlock, refs: ParserRef + ) -> ParsedSourceDefinition: + assert isinstance(block.target, SourceTarget) + source = block.target.source + table = block.target.table + unique_id = '.'.join([ + NodeType.Source, self.project.project_name, source.name, table.name + ]) description = table.description or '' source_description = source.description or '' - get_rendered(description, context) - get_rendered(source_description, context) + collect_docrefs(source, refs, None, description, source_description) loaded_at_field = table.loaded_at_field or source.loaded_at_field freshness = self._calculate_freshness(source, table) quoting = source.quoting.merged(table.quoting) + path = block.path.original_file_path - default_database = self.root_project_config.credentials.database return ParsedSourceDefinition( - package_name=package_name, - database=(source.database or default_database), + package_name=self.project.project_name, + database=(source.database or self.default_database), schema=(source.schema or source.name), identifier=(table.identifier or table.name), - root_path=root_dir, + root_path=self.project.project_root, path=path, original_file_path=path, columns=refs.column_info, @@ -495,164 +346,61 @@ def generate_source_node(self, source, table, path, package_name, root_dir, loaded_at_field=loaded_at_field, freshness=freshness, quoting=quoting, - resource_type=NodeType.Source, - fqn=[package_name, source.name, table.name] - ) - - def parse_source_table(self, source, table, path, package_name, root_dir): - refs = ParserRef() - test_target = {'source': source, 'table': table} - for column in table.columns: - column_tests = self._parse_column(test_target, column, - package_name, root_dir, path, - refs) - for node in column_tests: - yield 'test', node - - for test in table.tests: - try: - node = self.build_test_node(test_target, package_name, test, - root_dir, path) - except dbt.exceptions.CompilationException as exc: - dbt.exceptions.warn_or_error( - 'in {}: {}'.format(path, exc.msg), test - ) - continue - yield 'test', node - - node = self.generate_source_node(source, table, path, package_name, - root_dir, refs) - yield 'source', node - - def parse_source_entry(self, source, path, package_name, root_dir): - nodes = itertools.chain.from_iterable( - self.parse_source_table(source, table, path, package_name, - root_dir) - for table in source.tables - ) - for node_type, node in nodes: - yield node_type, node - - def _sources_validate(self, kwargs): - kwargs = self._renderer.render_schema_source(kwargs) - return UnparsedSourceDefinition.from_dict(kwargs) - - def parse_all(self, sources, path, package_name, root_dir): - """Parse all the model dictionaries in sources. - - :param List[dict] sources: The `sources` section of the schema.yml, as - a list of dicts. 
- :param str path: The path to the schema.yml file - :param str package_name: The name of the current package - :param str root_dir: The root directory of the search - """ - filtered = _filter_validate(path, 'sources', sources, - self._sources_validate) - nodes = itertools.chain.from_iterable( - self.parse_source_entry(source, path, package_name, root_dir) - for source in filtered + resource_type=SourceType(NodeType.Source), + fqn=[self.project.project_name, source.name, table.name], ) - for node_type, node in nodes: - yield node_type, node + def generate_node_patch( + self, block: TargetBlock, refs: ParserRef + ) -> ParsedNodePatch: + assert isinstance(block.target, UnparsedNodeUpdate) + description = block.target.description + collect_docrefs(block.target, refs, None, description) - -class SchemaParser: - def __init__(self, root_project_config, all_projects, macro_manifest): - self.root_project_config = root_project_config - self.all_projects = all_projects - self.macro_manifest = macro_manifest - - @classmethod - def find_schema_yml(cls, package_name, root_dir, relative_dirs): - """This is common to both v1 and v2 - look through the relative_dirs - under root_dir for .yml files yield pairs of filepath and loaded yaml - contents. - """ - extension = "[!.#~]*.yml" - - file_matches = dbt.clients.system.find_matching( - root_dir, - relative_dirs, - extension) - - for file_match in file_matches: - file_contents = dbt.clients.system.load_file_contents( - file_match.get('absolute_path'), strip=False) - test_path = file_match.get('relative_path', '') - - original_file_path = os.path.join(file_match.get('searched_path'), - test_path) - - try: - test_yml = dbt.clients.yaml_helper.load_yaml_text( - file_contents - ) - except dbt.exceptions.ValidationException as e: - test_yml = None - logger.info("Error reading {}:{} - Skipping\n{}".format( - package_name, test_path, e)) - - if test_yml is None: - continue - - yield original_file_path, test_yml - - def parse_schema(self, path, test_yml, package_name, root_dir): - model_parser = SchemaModelParser(self.root_project_config, - self.all_projects, - self.macro_manifest) - source_parser = SchemaSourceParser(self.root_project_config, - self.all_projects, - self.macro_manifest) - models = test_yml.get('models', []) - sources = test_yml.get('sources', []) - return itertools.chain( - model_parser.parse_all(models, path, package_name, root_dir), - source_parser.parse_all(sources, path, package_name, root_dir), + return ParsedNodePatch( + name=block.target.name, + original_file_path=block.path.original_file_path, + description=description, + columns=refs.column_info, + docrefs=refs.docrefs ) - def _parse_format_version(self, path, test_yml): - if 'version' not in test_yml: - dbt.exceptions.raise_invalid_schema_yml_version( - path, 'no version is specified' - ) - - version = test_yml['version'] - # if it's not an integer, the version is malformed, or not - # set. Either way, only 'version: 2' is supported. 
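For orientation, a minimal sketch (not part of the diff) of the dictionary shape the reworked parser works from once load_yaml_text has run: the version check requires the integer 2, and the 'models'/'sources' keys must hold lists of string-keyed dicts. The model, source, and test names here are hypothetical.

    # Hypothetical schema.yml contents, as returned by load_yaml_text().
    # _parse_format_version insists on the integer 2; read_yaml_models and
    # read_yaml_sources iterate the 'models' and 'sources' lists.
    schema_yml = {
        'version': 2,
        'models': [
            {'name': 'orders',
             'columns': [{'name': 'id', 'tests': ['unique']}]},
        ],
        'sources': [
            {'name': 'raw', 'tables': [{'name': 'events'}]},
        ],
    }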
- if not isinstance(version, int): - dbt.exceptions.raise_invalid_schema_yml_version( - path, 'the version is not an integer' - ) - return version - - def load_and_parse(self, package_name, root_dir, relative_dirs): - new_tests: Dict[str, ParsedNode] = {} - node_patches: Dict[str, dict] = {} - new_sources: Dict[str, ParsedSourceDefinition] = {} - - iterator = self.find_schema_yml(package_name, root_dir, relative_dirs) - - for path, test_yml in iterator: - version = self._parse_format_version(path, test_yml) - if version != 2: - dbt.exceptions.raise_invalid_schema_yml_version( - path, - 'version {} is not supported'.format(version) - ) - - results = self.parse_schema(path, test_yml, package_name, root_dir) - for result_type, node in results: - if result_type == 'patch': - node_patches[node.name] = node - elif result_type == 'test': - new_tests[node.unique_id] = node - elif result_type == 'source': - new_sources[node.unique_id] = node - else: - raise dbt.exceptions.InternalException( - 'Got invalid result type {} '.format(result_type) - ) - - return new_tests, node_patches, new_sources + def parse_target_model( + self, target_block: TargetBlock[UnparsedNodeUpdate] + ) -> ParsedNodePatch: + refs = self.parse_tests(target_block) + patch = self.generate_node_patch(target_block, refs) + return patch + + def parse_target_source( + self, target_block: TargetBlock[SourceTarget] + ) -> ParsedSourceDefinition: + refs = self.parse_tests(target_block) + patch = self.generate_source_node(target_block, refs) + return patch + + def parse_yaml_models(self, yaml_block: YamlBlock): + for node in self.read_yaml_models(yaml_block): + node_block = TargetBlock.from_yaml_block(yaml_block, node) + patch = self.parse_target_model(node_block) + self.results.add_patch(yaml_block.file, patch) + + def parse_yaml_sources( + self, yaml_block: YamlBlock + ): + for source in self.read_yaml_sources(yaml_block): + source_block = TargetBlock.from_yaml_block(yaml_block, source) + source_table = self.parse_target_source(source_block) + self.results.add_source(yaml_block.file, source_table) + + def parse_file(self, block: FileBlock) -> None: + dct = self._yaml_from_file(block.file) + # mark the file as seen, even if there are no macros in it + self.results.get_file(block.file) + if dct: + yaml_block = YamlBlock.from_file_block(block, dct) + + self._parse_format_version(yaml_block) + + self.parse_yaml_models(yaml_block) + self.parse_yaml_sources(yaml_block) diff --git a/core/dbt/parser/search.py b/core/dbt/parser/search.py new file mode 100644 index 00000000000..ec165016e25 --- /dev/null +++ b/core/dbt/parser/search.py @@ -0,0 +1,118 @@ +import os +from dataclasses import dataclass +from typing import ( + List, Callable, Iterable, Set, Union, Iterator, TypeVar, Generic +) + +from dbt.clients.jinja import extract_toplevel_blocks, BlockTag +from dbt.clients.system import find_matching +from dbt.config import Project +from dbt.contracts.graph.manifest import SourceFile, FilePath +from dbt.exceptions import CompilationException + + +@dataclass +class FileBlock: + file: SourceFile + + @property + def name(self): + base = os.path.basename(self.file.path.relative_path) + name, _ = os.path.splitext(base) + return name + + @property + def contents(self): + return self.file.contents + + @property + def path(self): + return self.file.path + + +@dataclass +class BlockContents(FileBlock): + block: BlockTag + + @property + def name(self): + return self.block.block_name + + @property + def contents(self): + return self.block.contents + + 
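As a quick illustration of how these wrappers relate to the BlockTag objects returned by extract_toplevel_blocks, here is a small sketch that is not part of the diff; the snapshot SQL is made up. BlockContents exposes only a tag's inner contents, while the FullBlock variant defined next keeps the surrounding tag markers.

    from dbt.clients.jinja import extract_toplevel_blocks

    raw = (
        "{% snapshot orders_snapshot %}\n"
        "select * from orders\n"
        "{% endsnapshot %}\n"
    )
    tags = extract_toplevel_blocks(
        raw, allowed_blocks={'snapshot'}, collect_raw_data=False
    )
    tag = tags[0]
    # tag.block_name == 'orders_snapshot'; tag.contents is just the inner SQL,
    # while tag.full_block includes the {% snapshot %} / {% endsnapshot %} markers.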
+@dataclass +class FullBlock(FileBlock): + block: BlockTag + + @property + def name(self): + return self.block.block_name + + @property + def contents(self): + return self.block.full_block + + +class FilesystemSearcher(Iterable[FilePath]): + def __init__( + self, project: Project, relative_dirs: List[str], extension: str + ) -> None: + self.project = project + self.relative_dirs = relative_dirs + self.extension = extension + + def __iter__(self) -> Iterator[FilePath]: + ext = "[!.#~]*" + self.extension + + root = self.project.project_root + + for result in find_matching(root, self.relative_dirs, ext): + file_match = FilePath(**{ + k: os.path.normcase(v) for k, v in result.items() + }) + yield file_match + + +Block = Union[BlockContents, FullBlock] + +BlockSearchResult = TypeVar('BlockSearchResult', BlockContents, FullBlock) + +BlockSearchResultFactory = Callable[[SourceFile, BlockTag], BlockSearchResult] + + +class BlockSearcher(Generic[BlockSearchResult], Iterable[BlockSearchResult]): + def __init__( + self, + source: List[FileBlock], + allowed_blocks: Set[str], + source_tag_factory: BlockSearchResultFactory + ) -> None: + self.source = source + self.allowed_blocks = allowed_blocks + self.source_tag_factory: BlockSearchResultFactory = source_tag_factory + + def extract_blocks(self, source_file: FileBlock) -> Iterable[BlockTag]: + try: + blocks = extract_toplevel_blocks( + source_file.contents, + allowed_blocks=self.allowed_blocks, + collect_raw_data=False + ) + # this makes mypy happy, and this is an invariant we really need + for block in blocks: + assert isinstance(block, BlockTag) + yield block + + except CompilationException as exc: + if exc.node is None: + # TODO(jeb): attach info about resource type/file path here + exc.node = NotImplemented + raise + + def __iter__(self) -> Iterator[BlockSearchResult]: + for entry in self.source: + for block in self.extract_blocks(entry): + yield self.source_tag_factory(entry.file, block) diff --git a/core/dbt/parser/seeds.py b/core/dbt/parser/seeds.py index 7da8ace3b46..83027584bd5 100644 --- a/core/dbt/parser/seeds.py +++ b/core/dbt/parser/seeds.py @@ -1,65 +1,32 @@ -import os -from typing import Dict, Any - -import dbt.flags -import dbt.clients.agate_helper -import dbt.clients.system -import dbt.context.parser -import dbt.contracts.project -import dbt.exceptions - -from dbt.node_types import NodeType -from dbt.logger import GLOBAL_LOGGER as logger +from dbt.contracts.graph.manifest import SourceFile, FilePath from dbt.contracts.graph.parsed import ParsedSeedNode -from dbt.contracts.graph.unparsed import UnparsedNode -from dbt.parser.base import MacrosKnownParser +from dbt.node_types import NodeType +from dbt.source_config import SourceConfig +from dbt.parser.base import SimpleSQLParser +from dbt.parser.search import FileBlock, FilesystemSearcher -class SeedParser(MacrosKnownParser): - @classmethod - def parse_seed_file(cls, file_match, root_dir, package_name): - """Parse the given seed file, returning an UnparsedNode and the agate - table. 
- """ - abspath = file_match['absolute_path'] - logger.debug("Parsing {}".format(abspath)) - table_name = os.path.basename(abspath)[:-4] - node = UnparsedNode( - path=file_match['relative_path'], - name=table_name, - root_path=root_dir, - resource_type=NodeType.Seed, - # Give this raw_sql so it conforms to the node spec, - # use dummy text so it doesn't look like an empty node - raw_sql='-- csv --', - package_name=package_name, - original_file_path=os.path.join(file_match.get('searched_path'), - file_match.get('relative_path')), +class SeedParser(SimpleSQLParser[ParsedSeedNode]): + def get_paths(self): + return FilesystemSearcher( + self.project, self.project.data_paths, '.csv' ) - return node - - def load_and_parse(self, package_name, root_dir, relative_dirs, tags=None): - """Load and parse seed files in a list of directories. Returns a dict - that maps unique ids onto ParsedNodes""" - extension = "[!.#~]*.csv" + def parse_from_dict(self, dct, validate=True) -> ParsedSeedNode: + return ParsedSeedNode.from_dict(dct, validate=validate) - file_matches = dbt.clients.system.find_matching( - root_dir, - relative_dirs, - extension) + @property + def resource_type(self) -> NodeType: + return NodeType.Seed - result = {} - for file_match in file_matches: - node = self.parse_seed_file(file_match, root_dir, package_name) - node_path = self.get_path(NodeType.Seed, package_name, node.name) - parsed = self.parse_node(node, node_path, - self.all_projects.get(package_name), - tags=tags) - result[node_path] = parsed + @classmethod + def get_compiled_path(cls, block: FileBlock): + return block.path.relative_path - return result + def render_with_context( + self, parsed_node: ParsedSeedNode, config: SourceConfig + ) -> None: + """Seeds don't need to do any rendering.""" - def parse_from_dict(self, parsed_dict: Dict[str, Any]) -> ParsedSeedNode: - """Given a dictionary, return the parsed entity for this parser""" - return ParsedSeedNode.from_dict(parsed_dict) + def load_file(self, match: FilePath) -> SourceFile: + return SourceFile.seed(match) diff --git a/core/dbt/parser/snapshots.py b/core/dbt/parser/snapshots.py index ddcadc34d5b..c9aebd99327 100644 --- a/core/dbt/parser/snapshots.py +++ b/core/dbt/parser/snapshots.py @@ -1,92 +1,87 @@ - -from dbt.contracts.graph.parsed import ParsedSnapshotNode, \ - IntermediateSnapshotNode -from dbt.exceptions import CompilationException, validator_error_message -from dbt.node_types import NodeType -from dbt.parser.base_sql import BaseSqlParser, SQLParseResult -import dbt.clients.jinja -import dbt.utils +import os +from typing import List from hologram import ValidationError +from dbt.contracts.graph.parsed import ( + IntermediateSnapshotNode, ParsedSnapshotNode +) +from dbt.exceptions import ( + CompilationException, validator_error_message +) +from dbt.node_types import NodeType +from dbt.parser.base import SQLParser +from dbt.parser.search import ( + FilesystemSearcher, BlockContents, BlockSearcher, FileBlock +) +from dbt.utils import split_path -def set_snapshot_attributes(node): - if node.config.target_database: - node.database = node.config.target_database - if node.config.target_schema: - node.schema = node.config.target_schema - return node +class SnapshotParser( + SQLParser[IntermediateSnapshotNode, ParsedSnapshotNode] +): + def get_paths(self): + return FilesystemSearcher( + self.project, self.project.snapshot_paths, '.sql' + ) + def parse_from_dict(self, dct, validate=True) -> IntermediateSnapshotNode: + return IntermediateSnapshotNode.from_dict(dct, 
validate=validate) -class SnapshotParser(BaseSqlParser): - def parse_snapshots_from_file(self, file_node, tags=None): - # the file node has a 'raw_sql' field that contains the jinja data with - # (we hope!) `snapshot` blocks - try: - blocks = dbt.clients.jinja.extract_toplevel_blocks( - file_node['raw_sql'], - allowed_blocks={'snapshot'}, - collect_raw_data=False - ) - except CompilationException as exc: - if exc.node is None: - exc.node = file_node - raise - for block in blocks: - name = block.block_name - raw_sql = block.contents - updates = { - 'raw_sql': raw_sql, - 'name': name, - } - yield dbt.utils.deep_merge(file_node, updates) - - @classmethod - def get_compiled_path(cls, name, relative_path): - return relative_path + @property + def resource_type(self) -> NodeType: + return NodeType.Snapshot @classmethod - def get_fqn(cls, node, package_project_config, extra=[]): - parts = dbt.utils.split_path(node.path) - fqn = [package_project_config.project_name] - fqn.extend(parts[:-1]) - fqn.extend(extra) - fqn.append(node.name) - - return fqn - - def parse_from_dict(self, parsed_dict) -> IntermediateSnapshotNode: - return IntermediateSnapshotNode.from_dict(parsed_dict) - - @staticmethod - def validate_snapshots(node): - if node.resource_type == NodeType.Snapshot: - try: - parsed_node = ParsedSnapshotNode.from_dict(node.to_dict()) - return set_snapshot_attributes(parsed_node) - - except ValidationError as exc: - raise CompilationException(validator_error_message(exc), node) - else: - return node + def get_compiled_path(cls, block: FileBlock): + return block.path.relative_path + + def set_snapshot_attributes(self, node): + # use the target_database setting if we got it, otherwise the + # `database` value of the node (ultimately sourced from the `database` + # config value), and if that is not set, use the database defined in + # the adapter's credentials. + if node.config.target_database: + node.database = node.config.target_database + elif not node.database: + node.database = self.root_project.credentials.database + + # the target schema must be set if we got here, so overwrite the node's + # schema + node.schema = node.config.target_schema - def parse_sql_nodes(self, nodes, tags=None): - if tags is None: - tags = [] + return node - results = SQLParseResult() + def get_fqn(self, path: str, name: str) -> List[str]: + """Get the FQN for the node. This impacts node selection and config + application. - # in snapshots, we have stuff in blocks. - for file_node in nodes: - snapshot_nodes = list( - self.parse_snapshots_from_file(file_node, tags=tags) - ) - found = super().parse_sql_nodes(nodes=snapshot_nodes, tags=tags) - # Our snapshots are all stored as IntermediateSnapshotNodes, so - # convert them to their final form - found.parsed = {k: self.validate_snapshots(v) for - k, v in found.parsed.items()} + On snapshots, the fqn includes the filename. 
+ """ + no_ext = os.path.splitext(path)[0] + fqn = [self.project.project_name] + fqn.extend(split_path(no_ext)) + fqn.append(name) + return fqn - results.update(found) - return results + def transform(self, node: IntermediateSnapshotNode) -> ParsedSnapshotNode: + try: + parsed_node = ParsedSnapshotNode.from_dict(node.to_dict()) + self.set_snapshot_attributes(parsed_node) + return parsed_node + + except ValidationError as exc: + raise CompilationException(validator_error_message(exc), node) + + def parse_file(self, file_block: FileBlock) -> None: + blocks = BlockSearcher( + source=[file_block], + allowed_blocks={'snapshot'}, + source_tag_factory=BlockContents, + ) + for block in blocks: + self.parse_node(block) + # in case there are no snapshots declared, we still want to mark this + # file as seen. But after we've finished, because we don't want to add + # files with syntax errors + self.results.get_file(file_block.file) diff --git a/core/dbt/parser/util.py b/core/dbt/parser/util.py index aeacec495b9..74801898f96 100644 --- a/core/dbt/parser/util.py +++ b/core/dbt/parser/util.py @@ -1,17 +1,17 @@ +from typing import Optional import dbt.exceptions import dbt.utils from dbt.node_types import NodeType from dbt.contracts.graph.parsed import ColumnInfo +from dbt.config import Project -def docs(node, manifest, config, column_name=None): +def docs(node, manifest, current_project: str, column_name=None): """Return a function that will process `doc()` references in jinja, look them up in the manifest, and return the appropriate block contents. """ - current_project = config.project_name - - def do_docs(*args): + def do_docs(*args: str): if len(args) == 1: doc_package_name = None doc_name = args[0] @@ -38,8 +38,11 @@ class ParserUtils: DISABLED = object() @classmethod - def resolve_source(cls, manifest, target_source_name, - target_table_name, current_project, node_package): + def resolve_source( + cls, manifest, target_source_name: Optional[str], + target_table_name: Optional[str], current_project: str, + node_package: str + ): candidate_targets = [current_project, node_package, None] target_source = None for candidate in candidate_targets: @@ -54,8 +57,11 @@ def resolve_source(cls, manifest, target_source_name, return None @classmethod - def resolve_ref(cls, manifest, target_model_name, target_model_package, - current_project, node_package): + def resolve_ref( + cls, manifest, target_model_name: Optional[str], + target_model_package: Optional[str], current_project: str, + node_package: str + ): if target_model_package is not None: return manifest.find_refable_by_name( target_model_name, @@ -89,8 +95,10 @@ def resolve_ref(cls, manifest, target_model_name, target_model_package, return None @classmethod - def resolve_doc(cls, manifest, target_doc_name, target_doc_package, - current_project, node_package): + def resolve_doc( + cls, manifest, target_doc_name: str, target_doc_package: Optional[str], + current_project: str, node_package: str + ): """Resolve the given documentation. This follows the same algorithm as resolve_ref except the is_enabled checks are unnecessary as docs are always enabled. 
@@ -122,7 +130,7 @@ def _get_node_column(cls, node, column_name): return column @classmethod - def process_docs_for_node(cls, manifest, current_project, node): + def process_docs_for_node(cls, manifest, current_project: str, node): for docref in node.docrefs: column_name = docref.column_name @@ -141,7 +149,7 @@ def process_docs_for_node(cls, manifest, current_project, node): obj.description = dbt.clients.jinja.get_rendered(raw, context) @classmethod - def process_docs_for_source(cls, manifest, current_project, source): + def process_docs_for_source(cls, manifest, current_project: str, source): context = { 'doc': docs(source, manifest, current_project), } @@ -154,13 +162,13 @@ def process_docs_for_source(cls, manifest, current_project, source): source.description = table_description source.source_description = source_description - for column_name, column_def in source.columns.items(): - column_desc = column_def.description or '' + for column in source.columns.values(): + column_desc = column.description column_desc = dbt.clients.jinja.get_rendered(column_desc, context) - column_def.description = column_desc + column.description = column_desc @classmethod - def process_docs(cls, manifest, current_project): + def process_docs(cls, manifest, current_project: str): for node in manifest.nodes.values(): if node.resource_type == NodeType.Source: cls.process_docs_for_source(manifest, current_project, node) @@ -169,7 +177,7 @@ def process_docs(cls, manifest, current_project): return manifest @classmethod - def process_refs_for_node(cls, manifest, current_project, node): + def process_refs_for_node(cls, manifest, current_project: str, node): """Given a manifest and a node in that manifest, process its refs""" for ref in node.refs: target_model = None @@ -202,16 +210,21 @@ def process_refs_for_node(cls, manifest, current_project, node): target_model_id = target_model.unique_id node.depends_on.nodes.append(target_model_id) - manifest.nodes[node.unique_id] = node + # TODO: I think this is extraneous, node should already be the same + # as manifest.nodes[node.unique_id] (we're mutating node here, not + # making a new one) + manifest.update_node(node) @classmethod - def process_refs(cls, manifest, current_project): - for node in manifest.nodes.values(): + def process_refs(cls, manifest, current_project: str): + # process_refs_for_node will mutate this + all_nodes = list(manifest.nodes.values()) + for node in all_nodes: cls.process_refs_for_node(manifest, current_project, node) return manifest @classmethod - def process_sources_for_node(cls, manifest, current_project, node): + def process_sources_for_node(cls, manifest, current_project: str, node): target_source = None for source_name, table_name in node.sources: target_source = cls.resolve_source( @@ -231,16 +244,17 @@ def process_sources_for_node(cls, manifest, current_project, node): continue target_source_id = target_source.unique_id node.depends_on.nodes.append(target_source_id) - manifest.nodes[node.unique_id] = node + manifest.update_node(node) @classmethod - def process_sources(cls, manifest, current_project): - for node in manifest.nodes.values(): + def process_sources(cls, manifest, current_project: str): + all_nodes = list(manifest.nodes.values()) + for node in all_nodes: cls.process_sources_for_node(manifest, current_project, node) return manifest @classmethod - def add_new_refs(cls, manifest, current_project, node, macros): + def add_new_refs(cls, manifest, current_project: Project, node, macros): """Given a new node that is not in the 
manifest, copy the manifest and insert the new node into it as if it were part of regular ref processing @@ -249,14 +263,10 @@ def add_new_refs(cls, manifest, current_project, node, macros): # it's ok for macros to silently override a local project macro name manifest.macros.update(macros) - if node.unique_id in manifest.nodes: - # this should be _impossible_ due to the fact that rpc calls get - # a unique ID that starts with 'rpc'! - raise dbt.exceptions.raise_duplicate_resource_name( - manifest.nodes[node.unique_id], node - ) - manifest.nodes[node.unique_id] = node - cls.process_sources_for_node(manifest, current_project, node) - cls.process_refs_for_node(manifest, current_project, node) - cls.process_docs_for_node(manifest, current_project, node) + manifest.add_nodes({node.unique_id: node}) + cls.process_sources_for_node( + manifest, current_project.project_name, node + ) + cls.process_refs_for_node(manifest, current_project.project_name, node) + cls.process_docs_for_node(manifest, current_project.project_name, node) return manifest diff --git a/core/dbt/py.typed b/core/dbt/py.typed new file mode 100644 index 00000000000..19e6b2eced6 --- /dev/null +++ b/core/dbt/py.typed @@ -0,0 +1 @@ +# dummy file, our types are defined inline diff --git a/core/dbt/parser/source_config.py b/core/dbt/source_config.py similarity index 100% rename from core/dbt/parser/source_config.py rename to core/dbt/source_config.py diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index 6cf3fd8206c..664c006abf0 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -1,5 +1,6 @@ -from abc import ABCMeta, abstractmethod import os +from abc import ABCMeta, abstractmethod +from typing import Type, Union from dbt.config import RuntimeConfig, Project from dbt.config.profile import read_profile, PROFILES_DIR @@ -37,7 +38,7 @@ def read_profiles(profiles_dir=None): class BaseTask(metaclass=ABCMeta): - ConfigType = NoneConfig + ConfigType: Union[Type[NoneConfig], Type[Project]] = NoneConfig def __init__(self, args, config): self.args = args diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index 90dcbbef79b..03f5316061e 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -1,4 +1,3 @@ -import os import signal import threading from typing import Union, List, Dict, Any @@ -6,11 +5,10 @@ from dbt.adapters.factory import get_adapter from dbt.clients.jinja import extract_toplevel_blocks from dbt.compilation import compile_manifest -from dbt.loader import load_all_projects from dbt.node_runners import CompileRunner, RPCCompileRunner from dbt.node_types import NodeType -from dbt.parser.analysis import RPCCallParser -from dbt.parser.macros import MacroParser +from dbt.parser.results import ParseResult +from dbt.parser.rpc import RPCCallParser, RPCMacroParser from dbt.parser.util import ParserUtils import dbt.ui.printer from dbt.logger import RPC_LOGGER as rpc_logger @@ -19,7 +17,6 @@ class CompileTask(GraphRunnableTask): - def raise_on_first_error(self): return True @@ -72,39 +69,23 @@ def _extract_request_data(self, data): return sql, macros def _get_exec_node(self, name, sql, macros): - request_path = os.path.join(self.config.target_path, 'rpc', name) - all_projects = load_all_projects(self.config) + results = ParseResult.rpc() macro_overrides = {} sql, macros = self._extract_request_data(sql) if macros: - macro_parser = MacroParser(self.config, all_projects) - macro_overrides.update(macro_parser.parse_macro_file( - macro_file_path='from remote system', - 
macro_file_contents=macros, - root_path=request_path, - package_name=self.config.project_name, - resource_type=NodeType.Macro - )) + macro_parser = RPCMacroParser(results, self.config) + for node in macro_parser.parse_remote(macros): + macro_overrides[node.unique_id] = node self._base_manifest.macros.update(macro_overrides) rpc_parser = RPCCallParser( - self.config, - all_projects=all_projects, - macro_manifest=self._base_manifest + results=results, + project=self.config, + root_project=self.config, + macro_manifest=self._base_manifest, ) - - node_dict = { - 'name': name, - 'root_path': request_path, - 'resource_type': NodeType.RPCCall, - 'path': name + '.sql', - 'original_file_path': 'from remote system', - 'package_name': self.config.project_name, - 'raw_sql': sql, - } - - unique_id, node = rpc_parser.parse_sql_node(node_dict) + node = rpc_parser.parse_remote(sql, name) self.manifest = ParserUtils.add_new_refs( manifest=self._base_manifest, current_project=self.config, diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index 6e9d269a813..8616df52c4c 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import json from dbt.task.runnable import GraphRunnableTask, ManifestTask diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index deb4faf7857..39c504655a3 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -15,6 +15,7 @@ from dbt.loader import GraphLoader import dbt.exceptions +import dbt.flags import dbt.ui.printer import dbt.utils @@ -26,12 +27,16 @@ def load_manifest(config): # performance trick: if the adapter has a manifest loaded, use that to - # avoid parsing internal macros twice. - internal_manifest = get_adapter(config).check_internal_manifest() - manifest = GraphLoader.load_all(config, - internal_manifest=internal_manifest) + # avoid parsing internal macros twice. Also, when loading the adapter's + # manifest, load the internal manifest to avoid running the graph loader + # twice.
+ adapter = get_adapter(config) - manifest.write(os.path.join(config.target_path, MANIFEST_FILE_NAME)) + internal = adapter.load_internal_manifest() + manifest = GraphLoader.load_all(config, internal_manifest=internal) + + if dbt.flags.WRITE_JSON: + manifest.write(os.path.join(config.target_path, MANIFEST_FILE_NAME)) return manifest @@ -182,15 +187,14 @@ def _handle_result(self, result): self.node_results.append(result) node = result.node - node_id = node.unique_id - self.manifest.nodes[node_id] = node + self.manifest.update_node(node) if result.error is not None: if is_ephemeral: cause = result else: cause = None - self._mark_dependent_errors(node_id, result, cause) + self._mark_dependent_errors(node.unique_id, result, cause) def execute_nodes(self): num_threads = self.config.threads @@ -288,7 +292,8 @@ def run(self): selected_uids = frozenset(n.unique_id for n in self._flattened_nodes) result = self.execute_with_hooks(selected_uids) - result.write(self.result_path()) + if dbt.flags.WRITE_JSON: + result.write(self.result_path()) self.task_end_messages(result.results) return result.results @@ -342,11 +347,11 @@ def task_end_messages(self, results): class RemoteCallable: - METHOD_NAME = None + METHOD_NAME: Optional[str] = None is_async = False @abstractmethod - def handle_request(self, **kwargs): + def handle_request(self): raise dbt.exceptions.NotImplementedException( 'from_kwargs not implemented' ) diff --git a/core/dbt/tracking.py b/core/dbt/tracking.py index fcda0ac8d33..ca98c3798c4 100644 --- a/core/dbt/tracking.py +++ b/core/dbt/tracking.py @@ -49,7 +49,7 @@ def http_get(self, payload): if self.is_good_status_code(r.status_code): sp_logger.info(msg) else: - sp_logger.warn(msg) + sp_logger.warning(msg) return r diff --git a/core/dbt/types.py b/core/dbt/types.py deleted file mode 100644 index 82af2e9652d..00000000000 --- a/core/dbt/types.py +++ /dev/null @@ -1,17 +0,0 @@ -from hologram import FieldEncoder, JsonSchemaMixin -from typing import Type, NewType - - -def NewRangedInteger(name: str, minimum: int, maximum: int) -> Type: - ranged = NewType(name, int) - - class RangeEncoder(FieldEncoder): - @property - def json_schema(self): - return {'type': 'integer', 'minimum': minimum, 'maximum': maximum} - - JsonSchemaMixin.register_field_encoders({ranged: RangeEncoder()}) - return ranged - - -Port = NewRangedInteger('Port', minimum=0, maximum=65535) diff --git a/core/dbt/ui/colors.py b/core/dbt/ui/colors.py index 798005f10cf..49b607ba812 100644 --- a/core/dbt/ui/colors.py +++ b/core/dbt/ui/colors.py @@ -1,7 +1,8 @@ +from typing import Dict import colorama -COLORS = { +COLORS: Dict[str, str] = { 'red': colorama.Fore.RED, 'green': colorama.Fore.GREEN, 'yellow': colorama.Fore.YELLOW, diff --git a/core/dbt/ui/printer.py b/core/dbt/ui/printer.py index b447b3676e0..49833ed4d93 100644 --- a/core/dbt/ui/printer.py +++ b/core/dbt/ui/printer.py @@ -1,3 +1,4 @@ +from typing import Dict, Optional, Tuple from dbt.logger import GLOBAL_LOGGER as logger from dbt.utils import get_materialization @@ -30,34 +31,36 @@ def get_timestamp(): return time.strftime("%H:%M:%S") -def color(text, color_code): +def color(text: str, color_code: str): if USE_COLORS: return "{}{}{}".format(color_code, text, COLOR_RESET_ALL) else: return text -def green(text): +def green(text: str): return color(text, COLOR_FG_GREEN) -def yellow(text): +def yellow(text: str): return color(text, COLOR_FG_YELLOW) -def red(text): +def red(text: str): return color(text, COLOR_FG_RED) -def print_timestamped_line(msg, use_color=None): +def 
print_timestamped_line(msg: str, use_color: Optional[str] = None): if use_color is not None: msg = color(msg, use_color) logger.info("{} | {}".format(get_timestamp(), msg)) -def print_fancy_output_line(msg, status, index, total, execution_time=None, - truncate=False): +def print_fancy_output_line( + msg: str, status: str, index: Optional[int], total: Optional[int], + execution_time: Optional[float] = None, truncate: bool = False +) -> None: if index is None or total is None: progress = '' else: @@ -78,16 +81,14 @@ def print_fancy_output_line(msg, status, index, total, execution_time=None, status_time = " in {execution_time:0.2f}s".format( execution_time=execution_time) - status_txt = status - output = "{justified} [{status}{status_time}]".format( - justified=justified, status=status_txt, status_time=status_time) + justified=justified, status=status, status_time=status_time) logger.info(output) -def get_counts(flat_nodes): - counts = {} +def get_counts(flat_nodes) -> str: + counts: Dict[str, int] = {} for node in flat_nodes: t = node.resource_type @@ -105,34 +106,38 @@ def get_counts(flat_nodes): return stat_line -def print_start_line(description, index, total): +def print_start_line(description: str, index: int, total: int) -> None: msg = "START {}".format(description) print_fancy_output_line(msg, 'RUN', index, total) -def print_hook_start_line(statement, index, total): +def print_hook_start_line(statement: str, index: int, total: int) -> None: msg = 'START hook: {}'.format(statement) print_fancy_output_line(msg, 'RUN', index, total, truncate=True) -def print_hook_end_line(statement, status, index, total, execution_time): +def print_hook_end_line( + statement: str, status: str, index: int, total: int, execution_time: float +) -> None: msg = 'OK hook: {}'.format(statement) # hooks don't fail into this path, so always green print_fancy_output_line(msg, green(status), index, total, execution_time=execution_time, truncate=True) -def print_skip_line(model, schema, relation, index, num_models): +def print_skip_line( + model, schema: str, relation: str, index: int, num_models: int +) -> None: msg = 'SKIP relation {}.{}'.format(schema, relation) print_fancy_output_line(msg, yellow('SKIP'), index, num_models) -def print_cancel_line(model): +def print_cancel_line(model) -> None: msg = 'CANCEL query {}'.format(model) print_fancy_output_line(msg, red('CANCEL'), index=None, total=None) -def get_printable_result(result, success, error): +def get_printable_result(result, success: str, error: str) -> Tuple[str, str]: if result.error is not None: info = 'ERROR {}'.format(error) status = red(result.status) @@ -143,7 +148,9 @@ def get_printable_result(result, success, error): return info, status -def print_test_result_line(result, schema_name, index, total): +def print_test_result_line( + result, schema_name, index: int, total: int +) -> None: model = result.node if result.error is not None: @@ -172,7 +179,9 @@ def print_test_result_line(result, schema_name, index, total): result.execution_time) -def print_model_result_line(result, description, index, total): +def print_model_result_line( + result, description: str, index: int, total: int +) -> None: info, status = get_printable_result(result, 'created', 'creating') print_fancy_output_line( @@ -183,7 +192,7 @@ def print_model_result_line(result, description, index, total): result.execution_time) -def print_snapshot_result_line(result, index, total): +def print_snapshot_result_line(result, index: int, total: int): model = result.node info, status = 
get_printable_result(result, 'snapshotted', 'snapshotting') @@ -199,7 +208,7 @@ def print_snapshot_result_line(result, index, total): result.execution_time) -def print_seed_result_line(result, schema_name, index, total): +def print_seed_result_line(result, schema_name: str, index: int, total: int): model = result.node info, status = get_printable_result(result, 'loaded', 'loading') @@ -215,7 +224,7 @@ def print_seed_result_line(result, schema_name, index, total): result.execution_time) -def print_freshness_result_line(result, index, total): +def print_freshness_result_line(result, index: int, total: int) -> None: if result.error: info = 'ERROR' color = red @@ -251,7 +260,7 @@ def print_freshness_result_line(result, index, total): ) -def interpret_run_result(result): +def interpret_run_result(result) -> str: if result.error is not None or result.fail: return 'error' elif result.skipped: @@ -262,7 +271,7 @@ def interpret_run_result(result): return 'pass' -def print_run_status_line(results): +def print_run_status_line(results) -> None: stats = { 'error': 0, 'skip': 0, @@ -280,7 +289,9 @@ def print_run_status_line(results): logger.info(stats_line.format(**stats)) -def print_run_result_error(result, newline=True, is_warning=False): +def print_run_result_error( + result, newline: bool = True, is_warning: bool = False +) -> None: if newline: logger.info("") @@ -314,15 +325,18 @@ def print_run_result_error(result, newline=True, is_warning=False): logger.info(line) -def print_skip_caused_by_error(model, schema, relation, index, num_models, - result): +def print_skip_caused_by_error( + model, schema: str, relation: str, index: int, num_models: int, result +) -> None: msg = ('SKIP relation {}.{} due to ephemeral model error' .format(schema, relation)) print_fancy_output_line(msg, red('ERROR SKIP'), index, num_models) print_run_result_error(result, newline=False) -def print_end_of_run_summary(num_errors, num_warnings, early_exit=False): +def print_end_of_run_summary( + num_errors: int, num_warnings: int, early_exit: bool = False +) -> None: error_plural = dbt.utils.pluralize(num_errors, 'error') warn_plural = dbt.utils.pluralize(num_warnings, 'warning') if early_exit: @@ -339,7 +353,7 @@ def print_end_of_run_summary(num_errors, num_warnings, early_exit=False): logger.info('{}'.format(message)) -def print_run_end_messages(results, early_exit=False): +def print_run_end_messages(results, early_exit: bool = False) -> None: errors = [r for r in results if r.error is not None or r.fail] warnings = [r for r in results if r.warn] print_end_of_run_summary(len(errors), len(warnings), early_exit) diff --git a/core/dbt/utils.py b/core/dbt/utils.py index cf7df29d429..27eb46572b5 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -8,6 +8,7 @@ import json import os from enum import Enum +from typing import Tuple, Type, Any import dbt.exceptions @@ -15,8 +16,9 @@ from dbt.node_types import NodeType from dbt.clients import yaml_helper +DECIMALS: Tuple[Type[Any], ...] try: - import cdecimal + import cdecimal # type: ignore except ImportError: DECIMALS = (decimal.Decimal,) else: @@ -91,6 +93,9 @@ def id_matches(unique_id, target_name, target_package, nodetypes, model): dbt.exceptions.raise_compiler_error(msg, model) resource_type, package_name, node_name = node_parts + if resource_type not in nodetypes: + return False + if node_type == NodeType.Source.value: if node_name.count('.') != 1: msg = "{} names must contain exactly 1 '.'
character"\ @@ -101,9 +106,6 @@ def id_matches(unique_id, target_name, target_package, nodetypes, model): msg = "{} names cannot contain '.' characters".format(node_type) dbt.exceptions.raise_compiler_error(msg, model) - if resource_type not in nodetypes: - return False - if target_name != node_name: return False @@ -290,22 +292,6 @@ def __init__(self, *args, **kwargs): self.__dict__ = self -def to_unicode(s, encoding): - try: - unicode - return unicode(s, encoding) - except NameError: - return s - - -def to_string(s): - try: - unicode - return s.encode('utf-8') - except NameError: - return s - - def get_materialization(node): return node.config.materialized diff --git a/core/dbt/version.py b/core/dbt/version.py index 6a5d1934869..fbe936f9d65 100644 --- a/core/dbt/version.py +++ b/core/dbt/version.py @@ -56,5 +56,5 @@ def get_version_information(): .format(version_msg)) -__version__ = '0.14.0' +__version__ = '0.15.0a1' installed = get_installed_version() diff --git a/core/setup.py b/core/setup.py index bd5e657449e..f4dec25f917 100644 --- a/core/setup.py +++ b/core/setup.py @@ -9,7 +9,7 @@ def read(fname): package_name = "dbt-core" -package_version = "0.14.0" +package_version = "0.15.0a1" description = """dbt (data build tool) is a command line tool that helps \ analysts and engineers transform data in their warehouse more effectively""" @@ -31,6 +31,7 @@ def read(fname): 'include/global_project/macros/*.sql', 'include/global_project/macros/**/*.sql', 'include/global_project/macros/**/**/*.sql', + 'py.typed', ] }, test_suite='test', diff --git a/dev_requirements.txt b/dev_requirements.txt index 7e39d83b66e..a5c31a1a5d4 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -7,3 +7,4 @@ tox==2.5.0 ipdb pytest-xdist>=1.28.0,<2 flaky>=3.5.3,<4 +mypy==0.720 diff --git a/plugins/bigquery/dbt/__init__.py b/plugins/bigquery/dbt/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/bigquery/dbt/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/bigquery/dbt/adapters/__init__.py b/plugins/bigquery/dbt/adapters/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/bigquery/dbt/adapters/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py index abfe3e38239..a737855195f 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/connections.py +++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py @@ -26,8 +26,6 @@ class BigQueryConnectionMethod(StrEnum): @dataclass class BigQueryCredentials(Credentials): method: BigQueryConnectionMethod - database: str - schema: str keyfile: Optional[str] = None keyfile_json: Optional[Dict[str, Any]] = None timeout_seconds: Optional[int] = 300 diff --git a/plugins/bigquery/dbt/adapters/bigquery/impl.py b/plugins/bigquery/dbt/adapters/bigquery/impl.py index e18c3f6661e..ab9eefef3f1 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/impl.py +++ b/plugins/bigquery/dbt/adapters/bigquery/impl.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import copy import dbt.deprecations diff --git a/plugins/bigquery/dbt/include/__init__.py b/plugins/bigquery/dbt/include/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/bigquery/dbt/include/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = 
__import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/bigquery/setup.py b/plugins/bigquery/setup.py index 531a66e49bb..eeb2ff40263 100644 --- a/plugins/bigquery/setup.py +++ b/plugins/bigquery/setup.py @@ -4,7 +4,7 @@ import os package_name = "dbt-bigquery" -package_version = "0.14.0" +package_version = "0.15.0a1" description = """The bigquery adapter plugin for dbt (data build tool)""" this_directory = os.path.abspath(os.path.dirname(__file__)) diff --git a/plugins/postgres/dbt/__init__.py b/plugins/postgres/dbt/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/postgres/dbt/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/postgres/dbt/adapters/__init__.py b/plugins/postgres/dbt/adapters/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/postgres/dbt/adapters/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/postgres/dbt/adapters/postgres/connections.py b/plugins/postgres/dbt/adapters/postgres/connections.py index 1915df00557..546016e0717 100644 --- a/plugins/postgres/dbt/adapters/postgres/connections.py +++ b/plugins/postgres/dbt/adapters/postgres/connections.py @@ -7,19 +7,17 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger -from dbt.types import Port +from dbt.helper_types import Port from dataclasses import dataclass from typing import Optional @dataclass class PostgresCredentials(Credentials): - database: str host: str user: str password: str port: Port - schema: str search_path: Optional[str] keepalives_idle: Optional[int] = 0 # 0 means to use the default value diff --git a/plugins/postgres/dbt/include/__init__.py b/plugins/postgres/dbt/include/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/postgres/dbt/include/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/postgres/dbt/include/postgres/macros/adapters.sql b/plugins/postgres/dbt/include/postgres/macros/adapters.sql index c087a17667c..b776da9fab0 100644 --- a/plugins/postgres/dbt/include/postgres/macros/adapters.sql +++ b/plugins/postgres/dbt/include/postgres/macros/adapters.sql @@ -78,7 +78,7 @@ {% endmacro %} {% macro postgres__check_schema_exists(information_schema, schema) -%} - {% if database -%} + {% if information_schema.database -%} {{ adapter.verify_database(information_schema.database) }} {%- endif -%} {% call statement('check_schema_exists', fetch_result=True, auto_begin=False) %} diff --git a/plugins/postgres/setup.py b/plugins/postgres/setup.py index 41d217ead19..f98cba072ca 100644 --- a/plugins/postgres/setup.py +++ b/plugins/postgres/setup.py @@ -4,7 +4,7 @@ import os package_name = "dbt-postgres" -package_version = "0.14.0" +package_version = "0.15.0a1" description = """The postgres adpter plugin for dbt (data build tool)""" this_directory = os.path.abspath(os.path.dirname(__file__)) diff --git a/plugins/redshift/dbt/__init__.py b/plugins/redshift/dbt/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/redshift/dbt/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/redshift/dbt/adapters/__init__.py b/plugins/redshift/dbt/adapters/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- 
a/plugins/redshift/dbt/adapters/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/redshift/dbt/adapters/redshift/connections.py b/plugins/redshift/dbt/adapters/redshift/connections.py index 4b93061bc41..f06479e5af1 100644 --- a/plugins/redshift/dbt/adapters/redshift/connections.py +++ b/plugins/redshift/dbt/adapters/redshift/connections.py @@ -1,5 +1,6 @@ -from contextlib import contextmanager import multiprocessing +from contextlib import contextmanager +from typing import NewType from dbt.adapters.postgres import PostgresConnectionManager from dbt.adapters.postgres import PostgresCredentials @@ -8,7 +9,7 @@ import boto3 -from dbt.types import NewRangedInteger +from hologram import FieldEncoder, JsonSchemaMixin from hologram.helpers import StrEnum from dataclasses import dataclass, field @@ -17,7 +18,16 @@ drop_lock = multiprocessing.Lock() -IAMDuration = NewRangedInteger('IAMDuration', minimum=900, maximum=3600) +IAMDuration = NewType('IAMDuration', int) + + +class IAMDurationEncoder(FieldEncoder): + @property + def json_schema(self): + return {'type': 'integer', 'minimum': 0, 'maximum': 65535} + + +JsonSchemaMixin.register_field_encoders({IAMDuration: IAMDurationEncoder()}) class RedshiftConnectionMethod(StrEnum): diff --git a/plugins/redshift/dbt/include/__init__.py b/plugins/redshift/dbt/include/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/redshift/dbt/include/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/redshift/setup.py b/plugins/redshift/setup.py index 2e5b37e7ba7..48bfb096811 100644 --- a/plugins/redshift/setup.py +++ b/plugins/redshift/setup.py @@ -4,7 +4,7 @@ import os package_name = "dbt-redshift" -package_version = "0.14.0" +package_version = "0.15.0a1" description = """The redshift adapter plugin for dbt (data build tool)""" this_directory = os.path.abspath(os.path.dirname(__file__)) diff --git a/plugins/snowflake/dbt/__init__.py b/plugins/snowflake/dbt/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/snowflake/dbt/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/snowflake/dbt/adapters/__init__.py b/plugins/snowflake/dbt/adapters/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/snowflake/dbt/adapters/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index e0a37c1c755..17829fda2db 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -20,8 +20,6 @@ class SnowflakeCredentials(Credentials): account: str user: str - database: str - schema: str warehouse: Optional[str] role: Optional[str] password: Optional[str] diff --git a/plugins/snowflake/dbt/adapters/snowflake/impl.py b/plugins/snowflake/dbt/adapters/snowflake/impl.py index ebcad66bef5..7797db2aac9 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/impl.py +++ b/plugins/snowflake/dbt/adapters/snowflake/impl.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from dbt.adapters.sql import SQLAdapter from dbt.adapters.snowflake import SnowflakeConnectionManager from dbt.adapters.snowflake import SnowflakeRelation diff --git 
a/plugins/snowflake/dbt/include/__init__.py b/plugins/snowflake/dbt/include/__init__.py deleted file mode 100644 index 69e3be50dac..00000000000 --- a/plugins/snowflake/dbt/include/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/plugins/snowflake/setup.py b/plugins/snowflake/setup.py index c4e46fb4147..b698520d467 100644 --- a/plugins/snowflake/setup.py +++ b/plugins/snowflake/setup.py @@ -4,7 +4,7 @@ import os package_name = "dbt-snowflake" -package_version = "0.14.0" +package_version = "0.15.0a1" description = """The snowflake adapter plugin for dbt (data build tool)""" this_directory = os.path.abspath(os.path.dirname(__file__)) diff --git a/setup.py b/setup.py index fac6523e884..ce42c5b9515 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ package_name = "dbt" -package_version = "0.14.0" +package_version = "0.15.0a1" description = """With dbt, data analysts and engineers can build analytics \ the way engineers build applications.""" diff --git a/test/integration/004_simple_snapshot_test/test-snapshots-invalid/snapshot.sql b/test/integration/004_simple_snapshot_test/test-snapshots-invalid/snapshot.sql index 88c0e72f0e4..6d0561f6267 100644 --- a/test/integration/004_simple_snapshot_test/test-snapshots-invalid/snapshot.sql +++ b/test/integration/004_simple_snapshot_test/test-snapshots-invalid/snapshot.sql @@ -1,4 +1,6 @@ -{% snapshot no_target_database %} +{# make sure to never name this anything with `target_schema` in the name, or the test will be invalid! #} +{% snapshot missing_field_target_underscore_schema %} + {# missing the mandatory target_schema parameter #} {{ config( unique_key='id || ' ~ "'-'" ~ ' || first_name', diff --git a/test/integration/004_simple_snapshot_test/test-snapshots-select/snapshot.sql b/test/integration/004_simple_snapshot_test/test-snapshots-select/snapshot.sql index e93581df83d..06245f36f70 100644 --- a/test/integration/004_simple_snapshot_test/test-snapshots-select/snapshot.sql +++ b/test/integration/004_simple_snapshot_test/test-snapshots-select/snapshot.sql @@ -30,10 +30,9 @@ {% snapshot snapshot_kelly %} - + {# This has no target_database set, which is allowed! #} {{ config( - target_database=var('target_database', database), target_schema=schema, unique_key='id || ' ~ "'-'" ~ ' || first_name', strategy='timestamp', diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py index 81e03039b38..a6c59ec40cb 100644 --- a/test/integration/029_docs_generate_tests/test_docs_generate.py +++ b/test/integration/029_docs_generate_tests/test_docs_generate.py @@ -1,11 +1,11 @@ -from __future__ import unicode_literals +import hashlib import json import os from datetime import datetime, timedelta from unittest.mock import ANY, patch from test.integration.base import DBTIntegrationTest, use_profile, AnyFloat, \ - AnyStringWith + AnyStringWith, normalize def _read_file(path): @@ -47,19 +47,6 @@ def _read_json(path): return json.load(fp) -def _normalize(path): - """On windows, neither is enough on its own: - - >>> normcase('C:\\documents/ALL CAPS/subdir\\..') - 'c:\\documents\\all caps\\subdir\\..' 
- >>> normpath('C:\\documents/ALL CAPS/subdir\\..') - 'C:\\documents\\ALL CAPS' - >>> normpath(normcase('C:\\documents/ALL CAPS/subdir\\..')) - 'c:\\documents\\all caps' - """ - return os.path.normcase(os.path.normpath(path)) - - def walk_files(path): for root, dirs, files in os.walk(path): for basename in files: @@ -79,7 +66,7 @@ def schema(self): @staticmethod def dir(path): - return _normalize(path) + return normalize(path) @property def models(self): @@ -122,8 +109,8 @@ def run_and_generate(self, extra=None, seed_count=1, model_count=1, alternate_db self.assertEqual(len(self.run_dbt(["seed", vars_arg])), seed_count) self.assertEqual(len(self.run_dbt(['run', vars_arg])), model_count) - os.remove(_normalize('target/manifest.json')) - os.remove(_normalize('target/run_results.json')) + os.remove(normalize('target/manifest.json')) + os.remove(normalize('target/run_results.json')) self.generate_start_time = datetime.utcnow() self.run_dbt(['docs', 'generate', vars_arg]) @@ -789,14 +776,14 @@ def verify_manifest_macros(self, manifest): self.assertTrue(len(macro['raw_sql']) > 10) without_sql = {k: v for k, v in macro.items() if k != 'raw_sql'} # Windows means we can't hard-code this. - helpers_path = _normalize('macros/materializations/helpers.sql') + helpers_path = normalize('macros/materializations/helpers.sql') self.assertEqual( without_sql, { 'path': helpers_path, 'original_file_path': helpers_path, 'package_name': 'dbt', - 'root_path': _normalize(os.path.join( + 'root_path': normalize(os.path.join( self.initial_dir, 'core', 'dbt','include', 'global_project' )), 'name': 'column_list', @@ -896,7 +883,7 @@ def expected_seeded_manifest(self, model_database=None): 'name': 'seed', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'resource_type': 'seed', - 'raw_sql': '-- csv --', + 'raw_sql': '', 'package_name': 'test', 'original_file_path': self.dir(os.path.join('seed', 'seed.csv')), @@ -938,7 +925,7 @@ def expected_seeded_manifest(self, model_database=None): 'original_file_path': schema_yml_path, 'package_name': 'test', 'patch_path': None, - 'path': _normalize('schema_test/not_null_model_id.sql'), + 'path': normalize('schema_test/not_null_model_id.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test_not_null(model=ref('model'), column_name='id') }}", 'refs': [['model']], 'resource_type': 'test', @@ -949,8 +936,8 @@ def expected_seeded_manifest(self, model_database=None): 'unique_id': 'test.test.not_null_model_id', 'docrefs': [], }, - 'test.test.nothing_model_': { - 'alias': 'nothing_model_', + 'test.test.test_nothing_model_': { + 'alias': 'test_nothing_model_', 'build_path': None, 'column_name': None, 'columns': {}, @@ -969,12 +956,12 @@ def expected_seeded_manifest(self, model_database=None): 'sources': [], 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', - 'fqn': ['test', 'schema_test', 'nothing_model_'], - 'name': 'nothing_model_', + 'fqn': ['test', 'schema_test', 'test_nothing_model_'], + 'name': 'test_nothing_model_', 'original_file_path': schema_yml_path, 'package_name': 'test', 'patch_path': None, - 'path': _normalize('schema_test/nothing_model_.sql'), + 'path': normalize('schema_test/test_nothing_model_.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test.test_nothing(model=ref('model'), ) }}", 'refs': [['model']], 'resource_type': 'test', @@ -982,7 +969,7 @@ def expected_seeded_manifest(self, model_database=None): 'schema': my_schema_name, 'database': self.default_database, 'tags': ['schema'], - 'unique_id': 'test.test.nothing_model_', + 
'unique_id': 'test.test.test_nothing_model_', 'docrefs': [], }, 'test.test.unique_model_id': { @@ -1010,7 +997,7 @@ def expected_seeded_manifest(self, model_database=None): 'original_file_path': schema_yml_path, 'package_name': 'test', 'patch_path': None, - 'path': _normalize('schema_test/unique_model_id.sql'), + 'path': normalize('schema_test/unique_model_id.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test_unique(model=ref('model'), column_name='id') }}", 'refs': [['model']], 'resource_type': 'test', @@ -1026,18 +1013,18 @@ def expected_seeded_manifest(self, model_database=None): 'model.test.model': ['seed.test.seed'], 'seed.test.seed': [], 'test.test.not_null_model_id': ['model.test.model'], - 'test.test.nothing_model_': ['model.test.model'], + 'test.test.test_nothing_model_': ['model.test.model'], 'test.test.unique_model_id': ['model.test.model'], }, 'child_map': { 'model.test.model': [ 'test.test.not_null_model_id', - 'test.test.nothing_model_', + 'test.test.test_nothing_model_', 'test.test.unique_model_id', ], 'seed.test.seed': ['model.test.model'], 'test.test.not_null_model_id': [], - 'test.test.nothing_model_': [], + 'test.test.test_nothing_model_': [], 'test.test.unique_model_id': [], }, 'docs': { @@ -1049,6 +1036,56 @@ def expected_seeded_manifest(self, model_database=None): 'user_id': None, }, 'disabled': [], + 'files': { + normalize('macros/dummy_test.sql'): { + 'path': self._path_to('macros', 'dummy_test.sql'), + 'checksum': self._checksum_file('macros/dummy_test.sql'), + 'docs': [], + 'macros': ['macro.test.test_nothing'], + 'nodes': [], + 'sources': [], + 'patches': [], + }, + normalize('models/model.sql'): { + 'path': self._path_to('models', 'model.sql'), + 'checksum': self._checksum_file('models/model.sql'), + 'docs': [], + 'macros': [], + 'nodes': ['model.test.model'], + 'sources': [], + 'patches': [], + }, + normalize('seed/seed.csv'): { + 'path': self._path_to('seed', 'seed.csv'), + 'checksum': { + 'name': 'path', + 'checksum': self._path_to('seed', 'seed.csv')['absolute_path'], + }, + 'docs': [], + 'macros': [], + 'nodes': ['seed.test.seed'], + 'patches': [], + 'sources': [], + }, + normalize('models/readme.md'): { + 'path': self._path_to('models', 'readme.md'), + 'checksum': self._checksum_file('models/readme.md'), + 'docs': [], + 'macros': [], + 'nodes': [], + 'patches': [], + 'sources': [], + }, + normalize('models/schema.yml'): { + 'path': self._path_to('models', 'schema.yml'), + 'checksum': self._checksum_file('models/schema.yml'), + 'docs': [], + 'macros': [], + 'nodes': ['test.test.unique_model_id', 'test.test.not_null_model_id', 'test.test.test_nothing_model_'], + 'patches': ['model'], + 'sources': [], + }, + }, } def expected_postgres_references_manifest(self, model_database=None): @@ -1057,7 +1094,25 @@ def expected_postgres_references_manifest(self, model_database=None): config_vars = {'alternate_db': model_database} my_schema_name = self.unique_schema() docs_path = self.dir('ref_models/docs.md') - docs_file = LineIndifferent(_read_file(docs_path).lstrip()) + + ephemeral_summary = LineIndifferent( + '{% docs ephemeral_summary %}\nA summmary table of the ephemeral copy of the seed data\n{% enddocs %}' + ) + source_info = LineIndifferent('{% docs source_info %}\nMy source\n{% enddocs %}') + summary_count = LineIndifferent( + '{% docs summary_count %}\nThe number of instances of the first name\n{% enddocs %}' + ) + summary_first_name = LineIndifferent( + '{% docs summary_first_name %}\nThe first name being summarized\n{% enddocs %}' + ) + table_info = 
LineIndifferent('{% docs table_info %}\nMy table\n{% enddocs %}') + view_summary = LineIndifferent( + '{% docs view_summary %}\nA view of the summary of the ephemeral copy of the seed data\n{% enddocs %}' + ) + column_info = LineIndifferent( + '{% docs column_info %}\nAn ID field\n{% enddocs %}' + ) + return { 'nodes': { 'model.test.ephemeral_copy': { @@ -1257,7 +1312,7 @@ def expected_postgres_references_manifest(self, model_database=None): 'package_name': 'test', 'patch_path': None, 'path': 'seed.csv', - 'raw_sql': '-- csv --', + 'raw_sql': '', 'refs': [], 'resource_type': 'seed', 'root_path': OneOf(self.test_root_dir, self.initial_dir), @@ -1318,7 +1373,7 @@ def expected_postgres_references_manifest(self, model_database=None): 'dbt.__overview__': ANY, 'test.column_info': { 'block_contents': 'An ID field', - 'file_contents': docs_file, + 'file_contents': column_info, 'name': 'column_info', 'original_file_path': docs_path, 'package_name': 'test', @@ -1330,7 +1385,7 @@ def expected_postgres_references_manifest(self, model_database=None): 'block_contents': ( 'A summmary table of the ephemeral copy of the seed data' ), - 'file_contents': docs_file, + 'file_contents': ephemeral_summary, 'name': 'ephemeral_summary', 'original_file_path': docs_path, 'package_name': 'test', @@ -1340,7 +1395,7 @@ def expected_postgres_references_manifest(self, model_database=None): }, 'test.source_info': { 'block_contents': 'My source', - 'file_contents': docs_file, + 'file_contents': source_info, 'name': 'source_info', 'original_file_path': docs_path, 'package_name': 'test', @@ -1350,7 +1405,7 @@ def expected_postgres_references_manifest(self, model_database=None): }, 'test.summary_count': { 'block_contents': 'The number of instances of the first name', - 'file_contents': docs_file, + 'file_contents': summary_count, 'name': 'summary_count', 'original_file_path': docs_path, 'package_name': 'test', @@ -1360,7 +1415,7 @@ def expected_postgres_references_manifest(self, model_database=None): }, 'test.summary_first_name': { 'block_contents': 'The first name being summarized', - 'file_contents': docs_file, + 'file_contents': summary_first_name, 'name': 'summary_first_name', 'original_file_path': docs_path, 'package_name': 'test', @@ -1370,7 +1425,7 @@ def expected_postgres_references_manifest(self, model_database=None): }, 'test.table_info': { 'block_contents': 'My table', - 'file_contents': docs_file, + 'file_contents': table_info, 'name': 'table_info', 'original_file_path': docs_path, 'package_name': 'test', @@ -1383,7 +1438,7 @@ def expected_postgres_references_manifest(self, model_database=None): 'A view of the summary of the ephemeral copy of the ' 'seed data' ), - 'file_contents': docs_file, + 'file_contents': view_summary, 'name': 'view_summary', 'original_file_path': docs_path, 'package_name': 'test', @@ -1412,6 +1467,83 @@ def expected_postgres_references_manifest(self, model_database=None): 'user_id': None, }, 'disabled': [], + 'files': { + normalize('macros/dummy_test.sql'): { + 'checksum': self._checksum_file('macros/dummy_test.sql'), + 'docs': [], + 'nodes': [], + 'macros': ['macro.test.test_nothing'], + 'patches': [], + 'path': self._path_to('macros', 'dummy_test.sql'), + 'sources': [], + }, + normalize('ref_models/view_summary.sql'): { + 'checksum': self._checksum_file('ref_models/view_summary.sql'), + 'docs': [], + 'macros': [], + 'nodes': ['model.test.view_summary'], + 'patches': [], + 'path': self._path_to('ref_models', 'view_summary.sql'), + 'sources': [], + }, + 
normalize('ref_models/ephemeral_summary.sql'): { + 'checksum': self._checksum_file('ref_models/ephemeral_summary.sql'), + 'docs': [], + 'macros': [], + 'nodes': ['model.test.ephemeral_summary'], + 'patches': [], + 'path': self._path_to('ref_models', 'ephemeral_summary.sql'), + 'sources': [], + }, + normalize('ref_models/ephemeral_copy.sql'): { + 'checksum': self._checksum_file('ref_models/ephemeral_copy.sql'), + 'nodes': ['model.test.ephemeral_copy'], + 'docs': [], + 'macros': [], + 'patches': [], + 'path': self._path_to('ref_models', 'ephemeral_copy.sql'), + 'sources': [], + }, + normalize('seed/seed.csv'): { + 'checksum': { + 'name': 'path', + 'checksum': self._path_to('seed', 'seed.csv')['absolute_path'], + }, + 'docs': [], + 'macros': [], + 'nodes': ['seed.test.seed'], + 'patches': [], + 'path': self._path_to('seed', 'seed.csv'), + 'sources': [], + }, + normalize('ref_models/docs.md'): { + 'checksum': self._checksum_file('ref_models/docs.md'), + 'docs': [ + 'test.ephemeral_summary', + 'test.summary_first_name', + 'test.summary_count', + 'test.view_summary', + 'test.source_info', + 'test.table_info', + 'test.column_info', + ], + 'macros': [], + 'nodes': [], + 'patches': [], + 'path': self._path_to('ref_models', 'docs.md'), + 'sources': [], + }, + normalize('ref_models/schema.yml'): { + 'checksum': self._checksum_file('ref_models/schema.yml'), + 'docs': [], + 'macros': [], + 'nodes': [], + 'patches': ['ephemeral_summary', 'view_summary'], + 'path': self._path_to('ref_models', 'schema.yml'), + 'sources': ['source.test.my_source.my_table'], + }, + + }, } def expected_bigquery_complex_manifest(self): @@ -1639,7 +1771,7 @@ def expected_bigquery_complex_manifest(self): 'name': 'seed', 'root_path': OneOf(self.test_root_dir, self.initial_dir), 'resource_type': 'seed', - 'raw_sql': '-- csv --', + 'raw_sql': '', 'package_name': 'test', 'original_file_path': self.dir('seed/seed.csv'), 'refs': [], @@ -1693,12 +1825,104 @@ def expected_bigquery_complex_manifest(self): 'user_id': None, }, 'disabled': [], + 'files': { + normalize('macros/dummy_test.sql'): { + 'checksum': self._checksum_file('macros/dummy_test.sql'), + 'path': self._path_to('macros', 'dummy_test.sql'), + 'macros': ['macro.test.test_nothing'], + 'patches': [], + 'docs': [], + 'nodes': [], + 'sources': [], + }, + normalize('bq_models/clustered.sql'): { + 'checksum': self._checksum_file('bq_models/clustered.sql'), + 'path': self._path_to('bq_models', 'clustered.sql'), + 'nodes': ['model.test.clustered'], + 'patches': [], + 'docs': [], + 'macros': [], + 'sources': [], + }, + normalize('bq_models/multi_clustered.sql'): { + 'checksum': self._checksum_file('bq_models/multi_clustered.sql'), + 'path': self._path_to('bq_models', 'multi_clustered.sql'), + 'nodes': ['model.test.multi_clustered'], + 'patches': [], + 'docs': [], + 'macros': [], + 'sources': [], + }, + normalize('bq_models/nested_table.sql'): { + 'checksum': self._checksum_file('bq_models/nested_table.sql'), + 'path': self._path_to('bq_models', 'nested_table.sql'), + 'nodes': ['model.test.nested_table'], + 'patches': [], + 'docs': [], + 'macros': [], + 'sources': [], + }, + normalize('bq_models/nested_view.sql'): { + 'checksum': self._checksum_file('bq_models/nested_view.sql'), + 'path': self._path_to('bq_models', 'nested_view.sql'), + 'nodes': ['model.test.nested_view'], + 'patches': [], + 'docs': [], + 'macros': [], + 'sources': [], + }, + normalize('seed/seed.csv'): { + 'checksum': { + 'name': 'path', + 'checksum': self._path_to('seed', 'seed.csv')['absolute_path'], + }, + 'path': 
self._path_to('seed', 'seed.csv'), + 'nodes': ['seed.test.seed'], + 'patches': [], + 'docs': [], + 'macros': [], + 'sources': [], + }, + normalize('bq_models/schema.yml'): { + 'checksum': self._checksum_file('bq_models/schema.yml'), + 'path': self._path_to('bq_models', 'schema.yml'), + 'nodes': [], + 'patches': ['nested_view', 'clustered', 'multi_clustered'], + 'docs': [], + 'macros': [], + 'sources': [], + }, + }, + } + + def _checksum_file(self, path): + """windows has silly git behavior that adds newlines, and python does + silly things if we just open(..., 'r').encode('utf-8'). + """ + with open(self.dir(path), 'rb') as fp: + hashed = hashlib.sha256(fp.read()).hexdigest() + return { + 'name': 'sha256', + 'checksum': hashed, + } + + def _path_to(self, searched_path: str, relative_path: str): + if searched_path == '.': + absolute_path = os.path.join(self.test_root_dir, relative_path) + else: + absolute_path = os.path.join(self.test_root_dir, searched_path, relative_path) + + return { + 'searched_path': normalize(searched_path), + 'relative_path': normalize(relative_path), + 'absolute_path': normalize(absolute_path), } def expected_redshift_incremental_view_manifest(self): model_sql_path = self.dir('rs_models/model.sql') my_schema_name = self.unique_schema() config_vars = {'alternate_db': self.default_database} + return { 'nodes': { 'model.test.model': { @@ -1767,7 +1991,7 @@ def expected_redshift_incremental_view_manifest(self): 'name': 'seed', 'root_path': self.test_root_dir, 'resource_type': 'seed', - 'raw_sql': '-- csv --', + 'raw_sql': '', 'package_name': 'test', 'original_file_path': self.dir('seed/seed.csv'), 'refs': [], @@ -1815,30 +2039,85 @@ def expected_redshift_incremental_view_manifest(self): 'user_id': None, }, 'disabled': [], + 'files': { + normalize('macros/dummy_test.sql'): { + 'checksum': self._checksum_file('macros/dummy_test.sql'), + 'path': self._path_to('macros', 'dummy_test.sql'), + 'docs': [], + 'macros': ['macro.test.test_nothing'], + 'nodes': [], + 'patches': [], + 'sources': [], + }, + normalize('rs_models/model.sql'): { + 'checksum': self._checksum_file('rs_models/model.sql'), + 'path': self._path_to('rs_models', 'model.sql'), + 'docs': [], + 'macros': [], + 'nodes': ['model.test.model'], + 'patches': [], + 'sources': [], + }, + normalize('seed/seed.csv'): { + 'checksum': { + 'name': 'path', + 'checksum': self._path_to('seed', 'seed.csv')['absolute_path'], + }, + 'path': self._path_to('seed', 'seed.csv'), + 'docs': [], + 'macros': [], + 'nodes': ['seed.test.seed'], + 'patches': [], + 'sources': [], + }, + normalize('rs_models/schema.yml'): { + 'checksum': self._checksum_file('rs_models/schema.yml'), + 'path': self._path_to('rs_models', 'schema.yml'), + 'docs': [], + 'macros': [], + 'nodes': [], + 'patches': ['model'], + 'sources': [] + }, + }, } + def verify_files(self, got_files, expected_files): + # I'm sure this will be fun on windows. We just want to look at this + # project's files. 
+ my_files = { + os.path.relpath(k, self.test_root_dir): v + for k, v in got_files.items() + if k.startswith(self.test_root_dir) + } + + self.assertEqual(set(my_files), set(expected_files)) + for k in my_files: + self.assertEqual(my_files[k], expected_files[k]) + def verify_manifest(self, expected_manifest): self.assertTrue(os.path.exists('./target/manifest.json')) manifest = _read_json('./target/manifest.json') - self.assertEqual( - set(manifest), - {'nodes', 'macros', 'parent_map', 'child_map', 'generated_at', - 'docs', 'metadata', 'docs', 'disabled'} - ) - - self.verify_manifest_macros(manifest) - manifest_without_extras = { - k: v for k, v in manifest.items() - if k not in {'macros', 'generated_at'} - } - self.assertBetween( - manifest['generated_at'], - start=self.generate_start_time - ) - self.assertEqual(manifest['disabled'], []) - self.assertEqual(manifest_without_extras, expected_manifest) + manifest_keys = frozenset({ + 'nodes', 'macros', 'parent_map', 'child_map', 'generated_at', + 'docs', 'metadata', 'docs', 'disabled', 'files' + }) + + self.assertEqual(frozenset(manifest), manifest_keys) + + for key in manifest_keys: + if key == 'macros': + self.verify_manifest_macros(manifest) + elif key == 'generated_at': + self.assertBetween(manifest['generated_at'], + start=self.generate_start_time) + elif key == 'files': + self.verify_files(manifest[key], expected_manifest[key]) + else: + self.assertIn(key, expected_manifest) # sanity check + self.assertEqual(manifest[key], expected_manifest[key]) def _quote(self, value): quote_char = '`' if self.adapter_type == 'bigquery' else '"' @@ -1896,7 +2175,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'warn': None, 'node': { 'alias': 'model', - 'build_path': _normalize( + 'build_path': normalize( 'target/compiled/test/model.sql' ), 'columns': { @@ -1948,12 +2227,10 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'warn': None, 'node': { 'alias': 'seed', - 'build_path': _normalize( - 'target/compiled/test/seed.csv' - ), + 'build_path': None, 'columns': {}, 'compiled': True, - 'compiled_sql': '-- csv --', + 'compiled_sql': '', 'config': { 'column_types': {}, 'enabled': True, @@ -1972,13 +2249,13 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'seed'], - 'injected_sql': '-- csv --', + 'injected_sql': '', 'name': 'seed', 'original_file_path': self.dir('seed/seed.csv'), 'package_name': 'test', 'patch_path': None, 'path': 'seed.csv', - 'raw_sql': '-- csv --', + 'raw_sql': '', 'refs': [], 'resource_type': 'seed', 'root_path': OneOf(self.test_root_dir, self.initial_dir), @@ -2000,7 +2277,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'warn': None, 'node': { 'alias': 'not_null_model_id', - 'build_path': _normalize('target/compiled/test/schema_test/not_null_model_id.sql'), + 'build_path': normalize('target/compiled/test/schema_test/not_null_model_id.sql'), 'column_name': 'id', 'columns': {}, 'compiled': True, @@ -2029,7 +2306,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'original_file_path': schema_yml_path, 'package_name': 'test', 'patch_path': None, - 'path': _normalize('schema_test/not_null_model_id.sql'), + 'path': normalize('schema_test/not_null_model_id.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test_not_null(model=ref('model'), column_name='id') }}", 'refs': [['model']], 'resource_type': 'test', @@ -2051,8 +2328,8 @@ def expected_run_results(self, 
quote_schema=True, quote_model=False, 'fail': None, 'warn': None, 'node': { - 'alias': 'nothing_model_', - 'build_path': _normalize('target/compiled/test/schema_test/nothing_model_.sql'), + 'alias': 'test_nothing_model_', + 'build_path': normalize('target/compiled/test/schema_test/test_nothing_model_.sql'), 'column_name': None, 'columns': {}, 'compiled': True, @@ -2075,13 +2352,13 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'docrefs': [], 'extra_ctes': [], 'extra_ctes_injected': True, - 'fqn': ['test', 'schema_test', 'nothing_model_'], + 'fqn': ['test', 'schema_test', 'test_nothing_model_'], 'injected_sql': AnyStringWith('select 0'), - 'name': 'nothing_model_', + 'name': 'test_nothing_model_', 'original_file_path': schema_yml_path, 'package_name': 'test', 'patch_path': None, - 'path': _normalize('schema_test/nothing_model_.sql'), + 'path': normalize('schema_test/test_nothing_model_.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test.test_nothing(model=ref('model'), ) }}", 'refs': [['model']], 'resource_type': 'test', @@ -2089,7 +2366,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'schema': schema, 'sources': [], 'tags': ['schema'], - 'unique_id': 'test.test.nothing_model_', + 'unique_id': 'test.test.test_nothing_model_', 'wrapped_sql': AnyStringWith('select 0'), }, 'thread_id': ANY, @@ -2104,7 +2381,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'warn': None, 'node': { 'alias': 'unique_model_id', - 'build_path': _normalize('target/compiled/test/schema_test/unique_model_id.sql'), + 'build_path': normalize('target/compiled/test/schema_test/unique_model_id.sql'), 'column_name': 'id', 'columns': {}, 'compiled': True, @@ -2133,7 +2410,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'original_file_path': schema_yml_path, 'package_name': 'test', 'patch_path': None, - 'path': _normalize('schema_test/unique_model_id.sql'), + 'path': normalize('schema_test/unique_model_id.sql'), 'raw_sql': "{{ config(severity='ERROR') }}{{ test_unique(model=ref('model'), column_name='id') }}", 'refs': [['model']], 'resource_type': 'test', @@ -2183,7 +2460,7 @@ def expected_postgres_references_run_results(self): 'warn': None, 'node': { 'alias': 'ephemeral_summary', - 'build_path': _normalize( + 'build_path': normalize( 'target/compiled/test/ephemeral_summary.sql' ), 'columns': { @@ -2272,7 +2549,7 @@ def expected_postgres_references_run_results(self): 'warn': None, 'node': { 'alias': 'view_summary', - 'build_path': _normalize( + 'build_path': normalize( 'target/compiled/test/view_summary.sql' ), 'alias': 'view_summary', @@ -2360,12 +2637,10 @@ def expected_postgres_references_run_results(self): 'warn': None, 'node': { 'alias': 'seed', - 'build_path': _normalize( - 'target/compiled/test/seed.csv' - ), + 'build_path': None, 'columns': {}, 'compiled': True, - 'compiled_sql': '-- csv --', + 'compiled_sql': '', 'config': { 'column_types': {}, 'enabled': True, @@ -2384,13 +2659,13 @@ def expected_postgres_references_run_results(self): 'extra_ctes': [], 'extra_ctes_injected': True, 'fqn': ['test', 'seed'], - 'injected_sql': '-- csv --', + 'injected_sql': '', 'name': 'seed', 'original_file_path': self.dir('seed/seed.csv'), 'package_name': 'test', 'patch_path': None, 'path': 'seed.csv', - 'raw_sql': '-- csv --', + 'raw_sql': '', 'refs': [], 'resource_type': 'seed', 'root_path': OneOf(self.test_root_dir, self.initial_dir), diff --git a/test/integration/033_event_tracking_test/test_events.py 
b/test/integration/033_event_tracking_test/test_events.py index de62a3956db..29b5e3cbd1b 100644 --- a/test/integration/033_event_tracking_test/test_events.py +++ b/test/integration/033_event_tracking_test/test_events.py @@ -271,7 +271,7 @@ def seed_context(project_id, user_id, invocation_id, version): 'model_materialization': 'seed', 'execution_time': ANY, - 'hashed_contents': '4f67ae18b42bc9468cc95ca0dab30531', + 'hashed_contents': 'd41d8cd98f00b204e9800998ecf8427e', 'model_id': '39bc2cd707d99bd3e600d2faaafad7ae', 'index': 1, diff --git a/test/integration/047_dbt_ls_test/test_ls.py b/test/integration/047_dbt_ls_test/test_ls.py index 2b0e9679851..bed65015d21 100644 --- a/test/integration/047_dbt_ls_test/test_ls.py +++ b/test/integration/047_dbt_ls_test/test_ls.py @@ -59,7 +59,7 @@ def expect_given_output(self, args, expectations): def expect_snapshot_output(self): expectations = { 'name': 'my_snapshot', - 'selector': 'test.my_snapshot', + 'selector': 'test.snapshot.my_snapshot', 'json': { 'name': 'my_snapshot', 'package_name': 'test', @@ -331,8 +331,8 @@ def expect_all_output(self): # but models don't! they just have (package.name) # sources are like models - (package.source_name.table_name) expected_default = { - 'test.my_snapshot', 'test.ephemeral', + 'test.snapshot.my_snapshot', 'test.sub.inner', 'test.outer', 'test.seed', diff --git a/test/integration/base.py b/test/integration/base.py index c12154e703a..ce21efd9e7d 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -29,6 +29,19 @@ INITIAL_ROOT = os.getcwd() +def normalize(path): + """On windows, neither is enough on its own: + + >>> normcase('C:\\documents/ALL CAPS/subdir\\..') + 'c:\\documents\\all caps\\subdir\\..' + >>> normpath('C:\\documents/ALL CAPS/subdir\\..') + 'C:\\documents\\ALL CAPS' + >>> normpath(normcase('C:\\documents/ALL CAPS/subdir\\..')) + 'c:\\documents\\all caps' + """ + return os.path.normcase(os.path.normpath(path)) + + class FakeArgs: def __init__(self): self.threads = 1 @@ -300,7 +313,7 @@ def setUp(self): _really_makedirs(self._logs_dir) self.test_original_source_path = _pytest_get_test_root() print('test_original_source_path={}'.format(self.test_original_source_path)) - self.test_root_dir = tempfile.mkdtemp(prefix='dbt-int-test-') + self.test_root_dir = normalize(tempfile.mkdtemp(prefix='dbt-int-test-')) print('test_root_dir={}'.format(self.test_root_dir)) os.chdir(self.test_root_dir) try: diff --git a/test/unit/test_agate_helper.py b/test/unit/test_agate_helper.py index e39afcba009..f5705f62ba9 100644 --- a/test/unit/test_agate_helper.py +++ b/test/unit/test_agate_helper.py @@ -1,4 +1,3 @@ -from __future__ import unicode_literals import unittest from datetime import datetime diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py index eb3a27b3944..3a9dc92b77f 100644 --- a/test/unit/test_compiler.py +++ b/test/unit/test_compiler.py @@ -108,7 +108,8 @@ def test__prepend_ctes__already_has_cte(self): docs={}, # '2018-02-14T09:15:13Z' generated_at=datetime(2018, 2, 14, 9, 15, 13), - disabled=[] + disabled=[], + files={}, ) result, output_graph = dbt.compilation.prepend_ctes( @@ -185,7 +186,8 @@ def test__prepend_ctes__no_ctes(self): }, docs={}, generated_at='2018-02-14T09:15:13Z', - disabled=[] + disabled=[], + files={}, ) result, output_graph = dbt.compilation.prepend_ctes( @@ -269,7 +271,8 @@ def test__prepend_ctes(self): }, docs={}, generated_at='2018-02-14T09:15:13Z', - disabled=[] + disabled=[], + files={}, ) result, output_graph = dbt.compilation.prepend_ctes( @@ -370,7 
+373,8 @@ def test__prepend_ctes__multiple_levels(self): }, docs={}, generated_at='2018-02-14T09:15:13Z', - disabled=[] + disabled=[], + files={}, ) result, output_graph = dbt.compilation.prepend_ctes( diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 86bca572c7e..b8eb444dd14 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -17,6 +17,8 @@ from dbt.semver import VersionSpecifier from dbt.task.run_operation import RunOperationTask +from .utils import normalize + INITIAL_ROOT = os.getcwd() @@ -168,8 +170,8 @@ def setUp(self): class BaseFileTest(BaseConfigTest): def setUp(self): - self.project_dir = os.path.normpath(tempfile.mkdtemp()) - self.profiles_dir = os.path.normpath(tempfile.mkdtemp()) + self.project_dir = normalize(tempfile.mkdtemp()) + self.profiles_dir = normalize(tempfile.mkdtemp()) super().setUp() def tearDown(self): @@ -919,7 +921,8 @@ def test_run_operation_task(self): self.assertEqual(os.getcwd(), INITIAL_ROOT) self.assertNotEqual(INITIAL_ROOT, self.project_dir) new_task = RunOperationTask.from_args(self.args) - self.assertEqual(os.getcwd(), self.project_dir) + self.assertEqual(os.path.realpath(os.getcwd()), + os.path.realpath(self.project_dir)) def test_run_operation_task_with_bad_path(self): self.args.project_dir = 'bad_path' diff --git a/test/unit/test_contracts_graph_manifest.py b/test/unit/test_contracts_graph_manifest.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/unit/test_docs_blocks.py b/test/unit/test_docs_blocks.py index 213d745d958..0676683e7fc 100644 --- a/test/unit/test_docs_blocks.py +++ b/test/unit/test_docs_blocks.py @@ -1,10 +1,12 @@ import os import unittest -from unittest import mock +from dbt.contracts.graph.manifest import SourceFile, FileHash, FilePath, Manifest +from dbt.contracts.graph.parsed import ParsedDocumentation from dbt.node_types import NodeType from dbt.parser import docs -from dbt.contracts.graph.unparsed import UnparsedDocumentationFile +from dbt.parser.results import ParseResult +from dbt.parser.search import FileBlock from .utils import config_from_parts_or_dicts @@ -34,18 +36,31 @@ a cookie then expires after 30 minutes of inactivity. 
''' - -TEST_DOCUMENTATION_FILE = r''' +SNOWPLOW_SESSIONS_BLOCK = r''' {{% docs snowplow_sessions %}} {snowplow_sessions_docs} {{% enddocs %}} +'''.format( + snowplow_sessions_docs=SNOWPLOW_SESSIONS_DOCS +).strip() + +SNOWPLOW_SESSIONS_SESSION_ID_BLOCK = r''' {{% docs snowplow_sessions__session_id %}} {snowplow_sessions_session_id_docs} {{% enddocs %}} '''.format( - snowplow_sessions_docs=SNOWPLOW_SESSIONS_DOCS, snowplow_sessions_session_id_docs=SNOWPLOW_SESSIONS_SESSION_ID_DOCS +).strip() + + +TEST_DOCUMENTATION_FILE = r''' +{sessions_block} + +{session_id_block} +'''.format( + sessions_block=SNOWPLOW_SESSIONS_BLOCK, + session_id_block=SNOWPLOW_SESSIONS_SESSION_ID_BLOCK, ) @@ -95,56 +110,57 @@ def setUp(self): project=subdir_project, profile=profile_data ) - @mock.patch('dbt.parser.docs.system') - def test_load_file(self, system): - system.load_file_contents.return_value = TEST_DOCUMENTATION_FILE - system.find_matching.return_value = [{ - 'relative_path': 'test_file.md', - 'absolute_path': self.testfile_path, - 'searched_path': self.subdir_path, - }] - results = list(docs.DocumentationParser.load_file( - 'some_package', self.root_path, ['test_subdir']) - ) - self.assertEqual(len(results), 1) - result = results[0] - self.assertEqual(result.package_name, 'some_package') - self.assertEqual(result.file_contents, TEST_DOCUMENTATION_FILE) - self.assertEqual(result.original_file_path, - self.testfile_path) - self.assertEqual(result.root_path, self.root_path) - self.assertEqual(result.resource_type, NodeType.Documentation) - self.assertEqual(result.path, 'test_file.md') - - def test_parse(self): - docfile = UnparsedDocumentationFile( - root_path=self.root_path, - path='test_file.md', - original_file_path=self.testfile_path, - package_name='some_package', - file_contents=TEST_DOCUMENTATION_FILE + def _build_file(self, contents, relative_path) -> FileBlock: + match = FilePath( + relative_path=relative_path, + absolute_path=self.testfile_path, + searched_path=self.subdir_path, ) - all_projects = { - 'root': self.root_project_config, - 'some_package': self.subdir_project_config - } - parser = docs.DocumentationParser(self.root_project_config, all_projects) - parsed = list(parser.parse(docfile)) - parsed.sort(key=lambda x: x.name) - - self.assertEqual(len(parsed), 2) - table = parsed[0] - column = parsed[1] - self.assertEqual(table.name, 'snowplow_sessions') - self.assertEqual(table.unique_id, - 'some_package.snowplow_sessions') - self.assertEqual(table.block_contents, SNOWPLOW_SESSIONS_DOCS.strip()) - - self.assertEqual(column.name, 'snowplow_sessions__session_id') - self.assertEqual(column.unique_id, - 'some_package.snowplow_sessions__session_id') - self.assertEqual( - column.block_contents, - SNOWPLOW_SESSIONS_SESSION_ID_DOCS.strip() - ) - + source_file = SourceFile(path=match, checksum=FileHash.empty()) + source_file.contents = contents + return FileBlock(file=source_file) + + def test_load_file(self): + parser = docs.DocumentationParser( + results=ParseResult.rpc(), + root_project=self.root_project_config, + project=self.subdir_project_config, + macro_manifest=Manifest.from_macros()) + + file_block = self._build_file(TEST_DOCUMENTATION_FILE, 'test_file.md') + + parser.parse_file(file_block) + results = sorted(parser.results.docs.values(), key=lambda n: n.name) + self.assertEqual(len(results), 2) + for result in results: + self.assertIsInstance(result, ParsedDocumentation) + self.assertEqual(result.package_name, 'some_package') + self.assertNotEqual(result.file_contents, TEST_DOCUMENTATION_FILE) + 
self.assertEqual(result.original_file_path, self.testfile_path) + self.assertEqual(result.root_path, self.subdir_path) + self.assertEqual(result.resource_type, NodeType.Documentation) + self.assertEqual(result.path, 'test_file.md') + + self.assertEqual(results[0].name, 'snowplow_sessions') + self.assertEqual(results[0].file_contents, SNOWPLOW_SESSIONS_BLOCK) + self.assertEqual(results[1].name, 'snowplow_sessions__session_id') + self.assertEqual(results[1].file_contents, SNOWPLOW_SESSIONS_SESSION_ID_BLOCK) + + def test_load_file_extras(self): + TEST_DOCUMENTATION_FILE + '{% model foo %}select 1 as id{% endmodel %}' + + parser = docs.DocumentationParser( + results=ParseResult.rpc(), + root_project=self.root_project_config, + project=self.subdir_project_config, + macro_manifest=Manifest.from_macros()) + + file_block = self._build_file(TEST_DOCUMENTATION_FILE, 'test_file.md') + + parser.parse_file(file_block) + results = sorted(parser.results.docs.values(), key=lambda n: n.name) + self.assertEqual(len(results), 2) + for result in results: + self.assertIsInstance(result, ParsedDocumentation) + self.assertEqual(results[0].name, 'snowplow_sessions') + self.assertEqual(results[1].name, 'snowplow_sessions__session_id') diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 2436756739e..fb9b0ef001d 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -7,9 +7,14 @@ import dbt.exceptions import dbt.flags import dbt.linker +import dbt.parser import dbt.config import dbt.utils import dbt.loader +from dbt.contracts.graph.manifest import FilePath, SourceFile, FileHash +from dbt.parser.results import ParseResult +from dbt.parser.base import BaseParser +from dbt.parser.search import FileBlock try: from queue import Empty @@ -26,9 +31,12 @@ class GraphTest(unittest.TestCase): def tearDown(self): self.write_gpickle_patcher.stop() self.load_projects_patcher.stop() - self.find_matching_patcher.stop() - self.load_file_contents_patcher.stop() + self.file_system_patcher.stop() self.get_adapter_patcher.stop() + self.mock_filesystem_constructor.stop() + self.mock_hook_constructor.stop() + self.load_patch.stop() + self.load_source_file_ptcher.stop() def setUp(self): dbt.flags.STRICT_MODE = True @@ -36,8 +44,12 @@ def setUp(self): self.write_gpickle_patcher = patch('networkx.write_gpickle') self.load_projects_patcher = patch('dbt.loader._load_projects') - self.find_matching_patcher = patch('dbt.clients.system.find_matching') - self.load_file_contents_patcher = patch('dbt.clients.system.load_file_contents') + self.file_system_patcher = patch.object( + dbt.parser.search.FilesystemSearcher, '__new__' + ) + self.hook_patcher = patch.object( + dbt.parser.hooks.HookParser, '__new__' + ) self.get_adapter_patcher = patch('dbt.context.parser.get_adapter') self.factory = self.get_adapter_patcher.start() @@ -68,27 +80,36 @@ def _load_projects(config, paths): self.mock_load_projects.side_effect = _load_projects self.mock_models = [] - self.mock_content = {} - def mock_find_matching(root_path, relative_paths_to_search, file_pattern): - if 'sql' not in file_pattern: - return [] + self.load_patch = patch('dbt.loader.make_parse_result') + self.mock_parse_result = self.load_patch.start() + self.mock_parse_result.return_value = ParseResult.rpc() - to_return = [] + self.load_source_file_ptcher = patch.object(BaseParser, 'load_file') + self.mock_source_file = self.load_source_file_ptcher.start() + self.mock_source_file.side_effect = lambda path: [n for n in self.mock_models if n.path == path][0] - if 
'models' in relative_paths_to_search: - to_return = to_return + self.mock_models - - return to_return + def filesystem_iter(iter_self): + if 'sql' not in iter_self.extension: + return [] + if 'models' not in iter_self.relative_dirs: + return [] + return [model.path for model in self.mock_models] - self.mock_find_matching = self.find_matching_patcher.start() - self.mock_find_matching.side_effect = mock_find_matching + def create_filesystem_searcher(cls, project, relative_dirs, extension): + result = MagicMock(project=project, relative_dirs=relative_dirs, extension=extension) + result.__iter__.side_effect = lambda: iter(filesystem_iter(result)) + return result - def mock_load_file_contents(path): - return self.mock_content[path] + def create_hook_patcher(cls, results, project, relative_dirs, extension): + result = MagicMock(results=results, project=project, relative_dirs=relative_dirs, extension=extension) + result.__iter__.side_effect = lambda: iter([]) + return result - self.mock_load_file_contents = self.load_file_contents_patcher.start() - self.mock_load_file_contents.side_effect = mock_load_file_contents + self.mock_filesystem_constructor = self.file_system_patcher.start() + self.mock_filesystem_constructor.side_effect = create_filesystem_searcher + self.mock_hook_constructor = self.hook_patcher.start() + self.mock_hook_constructor.side_effect = create_hook_patcher def get_config(self, extra_cfg=None): if extra_cfg is None: @@ -109,12 +130,20 @@ def get_compiler(self, project): def use_models(self, models): for k, v in models.items(): - path = os.path.abspath('models/{}.sql'.format(k)) - self.mock_models.append({ - 'searched_path': 'models', - 'absolute_path': path, - 'relative_path': '{}.sql'.format(k)}) - self.mock_content[path] = v + path = FilePath( + searched_path='models', + absolute_path=os.path.normcase(os.path.abspath('models/{}.sql'.format(k))), + relative_path='{}.sql'.format(k), + ) + source_file = SourceFile(path=path, checksum=FileHash.empty()) + source_file.contents = v + self.mock_models.append(source_file) + + def load_manifest(self, config): + loader = dbt.loader.GraphLoader(config, {config.project_name: config}) + loader.load() + return loader.create_manifest() + def test__single_model(self): self.use_models({ @@ -122,7 +151,8 @@ def test__single_model(self): }) config = self.get_config() - manifest = dbt.loader.GraphLoader.load_all(config) + manifest = self.load_manifest(config) + compiler = self.get_compiler(config) linker = compiler.compile(manifest) @@ -141,7 +171,7 @@ def test__two_models_simple_ref(self): }) config = self.get_config() - manifest = dbt.loader.GraphLoader.load_all(config) + manifest = self.load_manifest(config) compiler = self.get_compiler(config) linker = compiler.compile(manifest) @@ -178,7 +208,7 @@ def test__model_materializations(self): } config = self.get_config(cfg) - manifest = dbt.loader.GraphLoader.load_all(config) + manifest = self.load_manifest(config) compiler = self.get_compiler(config) linker = compiler.compile(manifest) @@ -211,7 +241,7 @@ def test__model_incremental(self): } config = self.get_config(cfg) - manifest = dbt.loader.GraphLoader.load_all(config) + manifest = self.load_manifest(config) compiler = self.get_compiler(config) linker = compiler.compile(manifest) @@ -235,9 +265,9 @@ def test__dependency_list(self): }) config = self.get_config() - graph = dbt.loader.GraphLoader.load_all(config) + manifest = self.load_manifest(config) compiler = self.get_compiler(config) - linker = compiler.compile(graph) + linker = 
compiler.compile(manifest) models = ('model_1', 'model_2', 'model_3', 'model_4') model_ids = ['model.test_models_compile.{}'.format(m) for m in models] diff --git a/test/unit/test_loader.py b/test/unit/test_loader.py new file mode 100644 index 00000000000..9324694f2bd --- /dev/null +++ b/test/unit/test_loader.py @@ -0,0 +1,145 @@ +import unittest +from unittest import mock +from os.path import join as pjoin + +from .utils import config_from_parts_or_dicts, normalize + +from dbt import loader +from dbt.contracts.graph.manifest import FileHash, FilePath, SourceFile +from dbt.parser import ParseResult +from dbt.parser.search import FileBlock + + +class MatchingHash(FileHash): + def __init__(self): + return super().__init__('', '') + + def __eq__(self, other): + return True + + +class MismatchedHash(FileHash): + def __init__(self): + return super().__init__('', '') + + def __eq__(self, other): + return False + + +class TestLoader(unittest.TestCase): + def setUp(self): + profile_data = { + 'target': 'test', + 'quoting': {}, + 'outputs': { + 'test': { + 'type': 'redshift', + 'host': 'localhost', + 'schema': 'analytics', + 'user': 'test', + 'pass': 'test', + 'dbname': 'test', + 'port': 1, + } + } + } + + root_project = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': normalize('/usr/src/app'), + } + + self.root_project_config = config_from_parts_or_dicts( + project=root_project, + profile=profile_data, + cli_vars='{"test_schema_name": "foo"}' + ) + self.parser = mock.MagicMock() + self.patched_result_builder = mock.patch('dbt.loader.make_parse_result') + self.mock_result_builder = self.patched_result_builder.start() + self.patched_result_builder.return_value = self._new_results() + self.loader = loader.GraphLoader( + self.root_project_config, + {'root': self.root_project_config} + ) + + def _new_results(self): + return ParseResult(MatchingHash(), MatchingHash(), {}) + + def _mismatched_file(self, searched, name): + return self._new_file(searched, name, False) + + def _matching_file(self, searched, name): + return self._new_file(searched, name, True) + + def _new_file(self, searched, name, match): + if match: + checksum = MatchingHash() + else: + checksum = MismatchedHash() + path = FilePath( + searched_path=normalize(searched), + relative_path=normalize(name), + absolute_path=normalize(pjoin(self.root_project_config.project_root, searched, name)), + ) + return SourceFile(path=path, checksum=checksum) + + def test_model_no_cache(self): + source_file = self._matching_file('models', 'model_1.sql') + self.parser.load_file.return_value = source_file + + old_results = None + + self.loader.parse_with_cache(source_file.path, self.parser, old_results) + # there was nothing in the cache, so parse_file should get called + # with a FileBlock that has the given source file in it + self.parser.parse_file.assert_called_once_with(FileBlock(file=source_file)) + + def test_model_cache_hit(self): + source_file = self._matching_file('models', 'model_1.sql') + self.parser.load_file.return_value = source_file + + source_file_dupe = self._matching_file('models', 'model_1.sql') + source_file_dupe.nodes.append('model.root.model_1') + + old_results = self._new_results() + old_results.files[source_file_dupe.path.search_key] = source_file_dupe + old_results.nodes = {'model.root.model_1': mock.MagicMock()} + + self.loader.parse_with_cache(source_file.path, self.parser, old_results) + # there was a cache hit, so parse_file should never have been called + self.parser.parse_file.assert_not_called() 
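
A minimal sketch of the caching contract that the test_model_cache_* cases in this new test/unit/test_loader.py pin down: the parser loads the file, the previous run's ParseResult is consulted by the file path's search_key, and the file is only re-parsed when it is missing from that cache or its checksum no longer compares equal. This is an illustration under assumptions, not the loader code from this patch; in particular the `results` attribute used for the accumulated parse state is assumed here, not taken from the diff.

    from dbt.parser.search import FileBlock

    def parse_with_cache(self, path, parser, old_results):
        source_file = parser.load_file(path)
        old_file = None
        if old_results is not None:
            # cached files are keyed by the path's search_key, as in the tests above
            old_file = old_results.files.get(source_file.path.search_key)
        if old_file is not None and old_file.checksum == source_file.checksum:
            # Cache hit: the file is unchanged, so carry its previously parsed
            # nodes forward instead of re-parsing (attribute names assumed).
            for unique_id in old_file.nodes:
                self.results.nodes[unique_id] = old_results.nodes[unique_id]
            return
        # New file or changed checksum: parse it from scratch.
        parser.parse_file(FileBlock(file=source_file))

Under this contract, test_model_cache_hit never reaches parse_file, while the no-cache, checksum-mismatch, and missing-file cases all fall through to parser.parse_file(FileBlock(file=source_file)), which is exactly what the assertions check.
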
+ + def test_model_cache_mismatch_checksum(self): + source_file = self._mismatched_file('models', 'model_1.sql') + self.parser.load_file.return_value = source_file + + source_file_dupe = self._mismatched_file('models', 'model_1.sql') + source_file_dupe.nodes.append('model.root.model_1') + + old_results = self._new_results() + old_results.files[source_file_dupe.path.search_key] = source_file_dupe + old_results.nodes = {'model.root.model_1': mock.MagicMock()} + + self.loader.parse_with_cache(source_file.path, self.parser, old_results) + # there was a cache checksum mismatch, so parse_file should get called + # with a FileBlock that has the given source file in it + self.parser.parse_file.assert_called_once_with(FileBlock(file=source_file)) + + def test_model_cache_missing_file(self): + source_file = self._matching_file('models', 'model_1.sql') + self.parser.load_file.return_value = source_file + + source_file_different = self._matching_file('models', 'model_2.sql') + source_file_different.nodes.append('model.root.model_2') + + old_results = self._new_results() + old_results.files[source_file_different.path.search_key] = source_file_different + old_results.nodes = {'model.root.model_2': mock.MagicMock()} + + self.loader.parse_with_cache(source_file.path, self.parser, old_results) + # the filename wasn't in the cache, so parse_file should get called + # with a FileBlock that has the given source file in it. + self.parser.parse_file.assert_called_once_with(FileBlock(file=source_file)) diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index 6f3068d66f5..7f2da430940 100644 --- a/test/unit/test_manifest.py +++ b/test/unit/test_manifest.py @@ -168,7 +168,8 @@ def setUp(self): @freezegun.freeze_time('2018-02-14T09:15:13Z') def test__no_nodes(self): manifest = Manifest(nodes={}, macros={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[], + files={}) self.assertEqual( manifest.writable_manifest().to_dict(), { @@ -180,6 +181,7 @@ def test__no_nodes(self): 'docs': {}, 'metadata': {}, 'disabled': [], + 'files': {}, } ) @@ -187,11 +189,13 @@ def test__no_nodes(self): def test__nested_nodes(self): nodes = copy.copy(self.nested_nodes) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[], + files={}) serialized = manifest.writable_manifest().to_dict() self.assertEqual(serialized['generated_at'], '2018-02-14T09:15:13Z') self.assertEqual(serialized['docs'], {}) self.assertEqual(serialized['disabled'], []) + self.assertEqual(serialized['files'], {}) parent_map = serialized['parent_map'] child_map = serialized['child_map'] # make sure there aren't any extra/missing keys. 
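
The unit tests in this file only assert that the new `files` section serializes as an empty mapping; the integration expectations earlier in this patch (expected_seeded_manifest and friends) show what a populated entry looks like. Roughly, each entry is keyed by the normalized file path and records what was parsed out of that file plus the checksum used for the parse cache. The values below are placeholders paraphrased from those expectations, not output from a real run:

    {
        'models/model.sql': {
            'path': {
                'searched_path': 'models',
                'relative_path': 'model.sql',
                'absolute_path': '<project root>/models/model.sql',
            },
            # seed files use {'name': 'path', ...} instead of a sha256 of the contents
            'checksum': {'name': 'sha256', 'checksum': '<sha256 of the file contents>'},
            'nodes': ['model.test.model'],  # nodes parsed from this file
            'docs': [],                     # docs blocks defined in this file
            'macros': [],                   # macros defined in this file
            'patches': [],                  # schema.yml patches recorded for this file
            'sources': [],                  # sources declared in this file
        },
    }
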
@@ -251,7 +255,8 @@ def test__nested_nodes(self): def test__to_flat_graph(self): nodes = copy.copy(self.nested_nodes) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[], + files={}) flat_graph = manifest.to_flat_graph() flat_nodes = flat_graph['nodes'] self.assertEqual(set(flat_graph), set(['nodes'])) @@ -285,7 +290,7 @@ def test_no_nodes_with_metadata(self, mock_user): config.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' manifest = Manifest(nodes={}, macros={}, docs={}, generated_at=datetime.utcnow(), disabled=[], - config=config) + config=config, files={}) metadata = { 'project_id': '098f6bcd4621d373cade4e832627b4f6', 'user_id': 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf', @@ -306,12 +311,14 @@ def test_no_nodes_with_metadata(self, mock_user): 'send_anonymous_usage_stats': False, }, 'disabled': [], + 'files': {}, } ) def test_get_resource_fqns_empty(self): manifest = Manifest(nodes={}, macros={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[], + files={}) self.assertEqual(manifest.get_resource_fqns(), {}) def test_get_resource_fqns(self): @@ -336,7 +343,8 @@ def test_get_resource_fqns(self): raw_sql='-- csv --' ) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[], + files={}) expect = { 'models': frozenset([ ('snowplow', 'events'), @@ -500,7 +508,8 @@ def setUp(self): @freezegun.freeze_time('2018-02-14T09:15:13Z') def test__no_nodes(self): manifest = Manifest(nodes={}, macros={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[], + files={}) self.assertEqual( manifest.writable_manifest().to_dict(), { @@ -512,6 +521,7 @@ def test__no_nodes(self): 'docs': {}, 'metadata': {}, 'disabled': [], + 'files': {}, } ) @@ -519,7 +529,8 @@ def test__no_nodes(self): def test__nested_nodes(self): nodes = copy.copy(self.nested_nodes) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[], + files={}) serialized = manifest.writable_manifest().to_dict() self.assertEqual(serialized['generated_at'], '2018-02-14T09:15:13Z') self.assertEqual(serialized['disabled'], []) @@ -582,7 +593,8 @@ def test__nested_nodes(self): def test__to_flat_graph(self): nodes = copy.copy(self.nested_nodes) manifest = Manifest(nodes=nodes, macros={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) + generated_at=datetime.utcnow(), disabled=[], + files={}) flat_graph = manifest.to_flat_graph() flat_nodes = flat_graph['nodes'] self.assertEqual(set(flat_graph), set(['nodes'])) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index d74477940bb..f22988d1198 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -1,32 +1,42 @@ import unittest from unittest import mock -from datetime import datetime import os import yaml import dbt.flags import dbt.parser -from dbt.parser import ModelParser, MacroParser, DataTestParser, \ - SchemaParser, ParserUtils -from dbt.parser.source_config import SourceConfig - -from dbt.node_types import NodeType -from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.parsed import ParsedModelNode, ParsedMacro, \ - ParsedNodePatch, ParsedSourceDefinition, NodeConfig, DependsOn, \ - ColumnInfo, ParsedTestNode, TestConfig -from dbt.contracts.graph.unparsed import 
FreshnessThreshold, Quoting, Time, \ - TimePeriod - -from .utils import config_from_parts_or_dicts - - -def get_os_path(unix_path): - return os.path.normpath(unix_path) +from dbt.exceptions import CompilationException +from dbt.parser import ( + ModelParser, MacroParser, DataTestParser, SchemaParser, ParserUtils, + ParseResult, SnapshotParser, AnalysisParser +) +from dbt.parser.search import FileBlock +from dbt.parser.schema_test_builders import YamlBlock + +from dbt.node_types import ( + NodeType, SnapshotType, MacroType, SourceType, TestType, AnalysisType +) +from dbt.contracts.graph.manifest import ( + Manifest, FilePath, SourceFile, FileHash +) +from dbt.contracts.graph.parsed import ( + ParsedModelNode, ParsedMacro, ParsedNodePatch, ParsedSourceDefinition, + NodeConfig, DependsOn, ColumnInfo, ParsedTestNode, TestConfig, + ParsedSnapshotNode, TimestampSnapshotConfig, TimestampStrategy, + ParsedAnalysisNode +) +from dbt.contracts.graph.unparsed import FreshnessThreshold + +from .utils import config_from_parts_or_dicts, normalize + + +def get_abs_os_path(unix_path): + return os.path.abspath(normalize(unix_path)) class BaseParserTest(unittest.TestCase): + maxDiff = None def setUp(self): dbt.flags.STRICT_MODE = True @@ -54,7 +64,7 @@ def setUp(self): 'name': 'root', 'version': '0.1', 'profile': 'test', - 'project-root': os.path.abspath('.'), + 'project-root': normalize('/usr/src/app'), } self.root_project_config = config_from_parts_or_dicts( @@ -67,7 +77,7 @@ def setUp(self): 'name': 'snowplow', 'version': '0.1', 'profile': 'test', - 'project-root': os.path.abspath('./dbt_modules/snowplow'), + 'project-root': get_abs_os_path('./dbt_modules/snowplow'), } self.snowplow_project_config = config_from_parts_or_dicts( @@ -81,2248 +91,601 @@ def setUp(self): self.patcher = mock.patch('dbt.context.parser.get_adapter') self.factory = self.patcher.start() + self.macro_manifest = Manifest.from_macros() + def tearDown(self): self.patcher.stop() - -class SourceConfigTest(BaseParserTest): - def test__source_config_single_call(self): - cfg = SourceConfig(self.root_project_config, self.root_project_config, - ['root', 'x'], NodeType.Model) - cfg.update_in_model_config({ - 'materialized': 'something', - 'sort': 'my sort key', - 'pre-hook': 'my pre run hook', - 'vars': {'a': 1, 'b': 2}, - }) - expect = { - 'column_types': {}, - 'enabled': True, - 'materialized': 'something', - 'post-hook': [], - 'pre-hook': ['my pre run hook'], - 'persist_docs': {}, - 'quoting': {}, - 'sort': 'my sort key', - 'tags': [], - 'vars': {'a': 1, 'b': 2}, - } - self.assertEqual(cfg.config, expect) - - def test__source_config_multiple_calls(self): - cfg = SourceConfig(self.root_project_config, self.root_project_config, - ['root', 'x'], NodeType.Model) - cfg.update_in_model_config({ - 'materialized': 'something', - 'sort': 'my sort key', - 'pre-hook': 'my pre run hook', - 'vars': {'a': 1, 'b': 2}, - }) - cfg.update_in_model_config({ - 'materialized': 'something else', - 'pre-hook': ['my other pre run hook', 'another pre run hook'], - 'vars': {'a': 4, 'c': 3}, - }) - expect = { - 'column_types': {}, - 'enabled': True, - 'materialized': 'something else', - 'persist_docs': {}, - 'post-hook': [], - 'pre-hook': [ - 'my pre run hook', - 'my other pre run hook', - 'another pre run hook', - ], - 'quoting': {}, - 'sort': 'my sort key', - 'tags': [], - 'vars': {'a': 4, 'b': 2, 'c': 3}, - } - self.assertEqual(cfg.config, expect) - - def test_source_config_all_keys_accounted_for(self): - used_keys = frozenset(SourceConfig.AppendListFields) | \ - 
frozenset(SourceConfig.ExtendDictFields) | \ - frozenset(SourceConfig.ClobberFields) - - self.assertEqual(used_keys, frozenset(SourceConfig.ConfigKeys)) - - def test__source_config_wrong_type(self): - # ExtendDict fields should handle non-dict inputs gracefully - self.root_project_config.models = {'persist_docs': False} - cfg = SourceConfig(self.root_project_config, self.root_project_config, - ['root', 'x'], NodeType.Model) - - with self.assertRaises(dbt.exceptions.CompilationException) as exc: - cfg.get_project_config(self.root_project_config) - - self.assertIn('must be a dict', str(exc.exception)) + def file_block_for(self, data: str, filename: str, searched: str): + root_dir = get_abs_os_path('./dbt_modules/snowplow') + filename = normalize(filename) + path = FilePath( + searched_path=searched, + relative_path=filename, + absolute_path=os.path.normpath(os.path.abspath( + os.path.join(root_dir, searched, filename) + )), + ) + source_file = SourceFile( + path=path, + checksum=FileHash.from_contents(data), + ) + source_file.contents = data + return FileBlock(file=source_file) + + def assert_has_results_length(self, results, files=1, macros=0, nodes=0, + sources=0, docs=0, patches=0, disabled=0): + self.assertEqual(len(results.files), files) + self.assertEqual(len(results.macros), macros) + self.assertEqual(len(results.nodes), nodes) + self.assertEqual(len(results.sources), sources) + self.assertEqual(len(results.docs), docs) + self.assertEqual(len(results.patches), patches) + self.assertEqual(sum(len(v) for v in results.disabled.values()), disabled) + + +SINGLE_TABLE_SOURCE = ''' +version: 2 +sources: + - name: my_source + tables: + - name: my_table +''' + +SINGLE_TABLE_SOURCE_TESTS = ''' +version: 2 +sources: + - name: my_source + tables: + - name: my_table + description: A description of my table + columns: + - name: color + tests: + - not_null: + severity: WARN + - accepted_values: + values: ['red', 'blue', 'green'] +''' + + +SINGLE_TABLE_MODEL_TESTS = ''' +version: 2 +models: + - name: my_model + description: A description of my model + columns: + - name: color + description: The color value + tests: + - not_null: + severity: WARN + - accepted_values: + values: ['red', 'blue', 'green'] + - foreign_package.test_case: + arg: 100 +''' class SchemaParserTest(BaseParserTest): - maxDiff = None - def setUp(self): super().setUp() - self.maxDiff = None - - self.macro_manifest = Manifest(macros={}, nodes={}, docs={}, - generated_at=datetime.utcnow(), - disabled=[]) - - self.model_config = NodeConfig.from_dict({ - 'enabled': True, - 'materialized': 'view', - 'persist_docs': {}, - 'post-hook': [], - 'pre-hook': [], - 'vars': {}, - 'quoting': {}, - 'column_types': {}, - 'tags': [], - }) - - self.test_config = TestConfig.from_dict({ - 'enabled': True, - 'materialized': 'view', - 'persist_docs': {}, - 'post-hook': [], - 'pre-hook': [], - 'vars': {}, - 'quoting': {}, - 'column_types': {}, - 'tags': [], - 'severity': 'ERROR', - }) - self.warn_test_config = self.test_config.replace(severity='WARN') - - self.disabled_config = { - 'enabled': False, - 'materialized': 'view', - 'post-hook': [], - 'pre-hook': [], - 'vars': {}, - 'quoting': {}, - 'column_types': {}, - 'tags': [], - } - - self._expected_source = ParsedSourceDefinition( - unique_id='source.root.my_source.my_table', - name='my_table', - description='my table description', + self.parser = SchemaParser( + results=ParseResult.rpc(), + project=self.snowplow_project_config, + root_project=self.root_project_config, + 
macro_manifest=self.macro_manifest, + ) + + def file_block_for(self, data, filename): + return super().file_block_for(data, filename, 'models') + + def yaml_block_for(self, test_yml: str, filename: str): + file_block = self.file_block_for(data=test_yml, filename=filename) + return YamlBlock.from_file_block( + src=file_block, + data=yaml.safe_load(test_yml), + ) + + +class SchemaParserSourceTest(SchemaParserTest): + def test__read_basic_source(self): + block = self.yaml_block_for(SINGLE_TABLE_SOURCE, 'test_one.yml') + self.assertEqual(len(list(self.parser.read_yaml_models(yaml=block))), 0) + results = list(self.parser.read_yaml_sources(yaml=block)) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].source.name, 'my_source') + self.assertEqual(results[0].table.name, 'my_table') + self.assertEqual(results[0].table.description, '') + self.assertEqual(len(results[0].tests), 0) + self.assertEqual(len(results[0].columns), 0) + + def test__parse_basic_source(self): + block = self.file_block_for(SINGLE_TABLE_SOURCE, 'test_one.yml') + self.parser.parse_file(block) + # self.parser.parse_yaml_sources(yaml_block=block) + self.assert_has_results_length(self.parser.results, sources=1) + src = list(self.parser.results.sources.values())[0] + expected = ParsedSourceDefinition( + package_name='snowplow', source_name='my_source', - source_description='my source description', - loader='some_loader', - package_name='root', - root_path=get_os_path('/usr/src/app'), - path='test_one.yml', - original_file_path='test_one.yml', - columns={ - 'id': ColumnInfo(name='id', description='user ID'), - }, + schema='my_source', + name='my_table', + loader='', + freshness=FreshnessThreshold(), + source_description='', + identifier='my_table', + fqn=['snowplow', 'my_source', 'my_table'], + database='test', + unique_id='source.snowplow.my_source.my_table', + root_path=get_abs_os_path('./dbt_modules/snowplow'), + path=normalize('models/test_one.yml'), + original_file_path=normalize('models/test_one.yml'), + resource_type=SourceType.Source, + ) + self.assertEqual(src, expected) + + def test__read_basic_source_tests(self): + block = self.yaml_block_for(SINGLE_TABLE_SOURCE_TESTS, 'test_one.yml') + self.assertEqual(len(list(self.parser.read_yaml_models(yaml=block))), 0) + results = list(self.parser.read_yaml_sources(yaml=block)) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].source.name, 'my_source') + self.assertEqual(results[0].table.name, 'my_table') + self.assertEqual(results[0].table.description, 'A description of my table') + self.assertEqual(len(results[0].columns), 1) + self.assertEqual(len(results[0].columns[0].tests), 2) + self.assertEqual(len(results[0].tests), 0) + + def test__parse_basic_source_tests(self): + block = self.file_block_for(SINGLE_TABLE_SOURCE_TESTS, 'test_one.yml') + self.parser.parse_file(block) + self.assertEqual(len(self.parser.results.nodes), 2) + self.assertEqual(len(self.parser.results.sources), 1) + self.assertEqual(len(self.parser.results.patches), 0) + src = list(self.parser.results.sources.values())[0] + self.assertEqual(src.source_name, 'my_source') + self.assertEqual(src.schema, 'my_source') + self.assertEqual(src.name, 'my_table') + self.assertEqual(src.description, 'A description of my table') + + tests = sorted(self.parser.results.nodes.values(), key=lambda n: n.unique_id) + + self.assertEqual(tests[0].config.severity, 'ERROR') + self.assertEqual(tests[0].tags, ['schema']) + self.assertEqual(tests[0].sources, [['my_source', 'my_table']]) + 
self.assertEqual(tests[0].column_name, 'color') + self.assertEqual(tests[0].fqn, ['snowplow', 'schema_test', tests[0].name]) + self.assertEqual(tests[1].config.severity, 'WARN') + self.assertEqual(tests[1].tags, ['schema']) + self.assertEqual(tests[1].sources, [['my_source', 'my_table']]) + self.assertEqual(tests[1].column_name, 'color') + self.assertEqual(tests[1].fqn, ['snowplow', 'schema_test', tests[1].name]) + + path = os.path.abspath('./dbt_modules/snowplow/models/test_one.yml') + self.assertIn(path, self.parser.results.files) + self.assertEqual(sorted(self.parser.results.files[path].nodes), + [t.unique_id for t in tests]) + self.assertIn(path, self.parser.results.files) + self.assertEqual(self.parser.results.files[path].sources, + ['source.snowplow.my_source.my_table']) + + +class SchemaParserModelsTest(SchemaParserTest): + def test__read_basic_model_tests(self): + block = self.yaml_block_for(SINGLE_TABLE_MODEL_TESTS, 'test_one.yml') + self.assertEqual(len(list(self.parser.read_yaml_sources(yaml=block))), 0) + results = list(self.parser.read_yaml_models(yaml=block)) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].name, 'my_model') + self.assertEqual(len(results[0].columns), 1) + self.assertEqual(len(results[0].columns[0].tests), 3) + self.assertEqual(len(results[0].tests), 0) + + def test__parse_basic_model_tests(self): + block = self.file_block_for(SINGLE_TABLE_MODEL_TESTS, 'test_one.yml') + self.parser.parse_file(block) + self.assert_has_results_length(self.parser.results, patches=1, nodes=3) + + patch = list(self.parser.results.patches.values())[0] + self.assertEqual(len(patch.columns), 1) + self.assertEqual(patch.name, 'my_model') + self.assertEqual(patch.description, 'A description of my model') + expected_patch = ParsedNodePatch( + name='my_model', + description='A description of my model', + columns={'color': ColumnInfo(name='color', description='The color value')}, docrefs=[], - freshness=FreshnessThreshold( - warn_after=Time(count=7, period=TimePeriod.hour), - error_after=Time(count=20, period=TimePeriod.hour) - ), - loaded_at_field='something', + original_file_path=normalize('models/test_one.yml'), + ) + self.assertEqual(patch, expected_patch) + + tests = sorted(self.parser.results.nodes.values(), key=lambda n: n.unique_id) + self.assertEqual(tests[0].config.severity, 'ERROR') + self.assertEqual(tests[0].tags, ['schema']) + self.assertEqual(tests[0].refs, [['my_model']]) + self.assertEqual(tests[0].column_name, 'color') + self.assertEqual(tests[0].package_name, 'snowplow') + self.assertTrue(tests[0].name.startswith('accepted_values_')) + self.assertEqual(tests[0].fqn, ['snowplow', 'schema_test', tests[0].name]) + self.assertEqual(tests[0].unique_id.split('.'), ['test', 'snowplow', tests[0].name]) + + # foreign packages are a bit weird, they include the macro package + # name in the test name + self.assertEqual(tests[1].config.severity, 'ERROR') + self.assertEqual(tests[1].tags, ['schema']) + self.assertEqual(tests[1].refs, [['my_model']]) + self.assertEqual(tests[1].column_name, 'color') + self.assertEqual(tests[1].column_name, 'color') + self.assertEqual(tests[1].fqn, ['snowplow', 'schema_test', tests[1].name]) + self.assertTrue(tests[1].name.startswith('foreign_package_test_case_')) + self.assertEqual(tests[1].package_name, 'snowplow') + self.assertEqual(tests[1].unique_id.split('.'), ['test', 'snowplow', tests[1].name]) + + self.assertEqual(tests[2].config.severity, 'WARN') + self.assertEqual(tests[2].tags, ['schema']) + self.assertEqual(tests[2].refs, 
[['my_model']]) + self.assertEqual(tests[2].column_name, 'color') + self.assertEqual(tests[2].package_name, 'snowplow') + self.assertTrue(tests[2].name.startswith('not_null_')) + self.assertEqual(tests[2].fqn, ['snowplow', 'schema_test', tests[2].name]) + self.assertEqual(tests[2].unique_id.split('.'), ['test', 'snowplow', tests[2].name]) + + path = os.path.abspath('./dbt_modules/snowplow/models/test_one.yml') + self.assertIn(path, self.parser.results.files) + self.assertEqual(sorted(self.parser.results.files[path].nodes), + [t.unique_id for t in tests]) + self.assertIn(path, self.parser.results.files) + self.assertEqual(self.parser.results.files[path].patches, ['my_model']) + + +class ModelParserTest(BaseParserTest): + def setUp(self): + super().setUp() + self.parser = ModelParser( + results=ParseResult.rpc(), + project=self.snowplow_project_config, + root_project=self.root_project_config, + macro_manifest=self.macro_manifest, + ) + + def file_block_for(self, data, filename): + return super().file_block_for(data, filename, 'models') + + def test_basic(self): + raw_sql = '{{ config(materialized="table") }}select 1 as id' + block = self.file_block_for(raw_sql, 'nested/model_1.sql') + self.parser.parse_file(block) + self.assert_has_results_length(self.parser.results, nodes=1) + node = list(self.parser.results.nodes.values())[0] + expected = ParsedModelNode( + alias='model_1', + name='model_1', database='test', - schema='foo', - identifier='bar', - resource_type=NodeType.Source, - quoting=Quoting(schema=True, identifier=False), - fqn=['root', 'my_source', 'my_table'] - ) - - self._expected_source_tests = [ - ParsedTestNode( - alias='source_accepted_values_my_source_my_table_id__a__b', - name='source_accepted_values_my_source_my_table_id__a__b', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.root.source_accepted_values_my_source_my_table_id__a__b', - fqn=['root', 'schema_test', - 'source_accepted_values_my_source_my_table_id__a__b'], - package_name='root', - original_file_path='test_one.yml', - root_path=get_os_path('/usr/src/app'), - refs=[], - sources=[['my_source', 'my_table']], - depends_on=DependsOn(), - config=self.test_config, - path=get_os_path( - 'schema_test/source_accepted_values_my_source_my_table_id__a__b.sql'), - tags=['schema'], - raw_sql="{{ config(severity='ERROR') }}{{ test_accepted_values(model=source('my_source', 'my_table'), column_name='id', values=['a', 'b']) }}", - description='', - columns={}, - column_name='id', - ), - ParsedTestNode( - alias='source_not_null_my_source_my_table_id', - name='source_not_null_my_source_my_table_id', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.root.source_not_null_my_source_my_table_id', - fqn=['root', 'schema_test', 'source_not_null_my_source_my_table_id'], - package_name='root', - root_path=get_os_path('/usr/src/app'), - refs=[], - sources=[['my_source', 'my_table']], - depends_on=DependsOn(), - config=self.test_config, - original_file_path='test_one.yml', - path=get_os_path('schema_test/source_not_null_my_source_my_table_id.sql'), - tags=['schema'], - raw_sql="{{ config(severity='ERROR') }}{{ test_not_null(model=source('my_source', 'my_table'), column_name='id') }}", - description='', - columns={}, - column_name='id', - ), - ParsedTestNode( - alias='source_relationships_my_source_my_table_id__id__ref_model_two_', - name='source_relationships_my_source_my_table_id__id__ref_model_two_', - database='test', - schema='analytics', - 
resource_type=NodeType.Test, - unique_id='test.root.source_relationships_my_source_my_table_id__id__ref_model_two_', # noqa - fqn=['root', 'schema_test', - 'source_relationships_my_source_my_table_id__id__ref_model_two_'], - package_name='root', - original_file_path='test_one.yml', - root_path=get_os_path('/usr/src/app'), - refs=[['model_two']], - sources=[['my_source', 'my_table']], - depends_on=DependsOn(), - config=self.test_config, - path=get_os_path('schema_test/source_relationships_my_source_my_table_id__id__ref_model_two_.sql'), # noqa - tags=['schema'], - raw_sql="{{ config(severity='ERROR') }}{{ test_relationships(model=source('my_source', 'my_table'), column_name='id', from='id', to=ref('model_two')) }}", - description='', - columns={}, - column_name='id', - ), - ParsedTestNode( - alias='source_some_test_my_source_my_table_value', - name='source_some_test_my_source_my_table_value', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.snowplow.source_some_test_my_source_my_table_value', - fqn=['snowplow', 'schema_test', 'source_some_test_my_source_my_table_value'], - package_name='snowplow', - original_file_path='test_one.yml', - root_path=get_os_path('/usr/src/app'), - refs=[], - sources=[['my_source', 'my_table']], - depends_on=DependsOn(), - config=self.warn_test_config, - path=get_os_path('schema_test/source_some_test_my_source_my_table_value.sql'), - tags=['schema'], - raw_sql="{{ config(severity='WARN') }}{{ snowplow.test_some_test(model=source('my_source', 'my_table'), key='value') }}", - description='', - columns={}, - ), - ParsedTestNode( - alias='source_unique_my_source_my_table_id', - name='source_unique_my_source_my_table_id', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.root.source_unique_my_source_my_table_id', - fqn=['root', 'schema_test', 'source_unique_my_source_my_table_id'], - package_name='root', - root_path=get_os_path('/usr/src/app'), - refs=[], - sources=[['my_source', 'my_table']], - depends_on=DependsOn(), - config=self.warn_test_config, - original_file_path='test_one.yml', - path=get_os_path('schema_test/source_unique_my_source_my_table_id.sql'), - tags=['schema'], - raw_sql="{{ config(severity='WARN') }}{{ test_unique(model=source('my_source', 'my_table'), column_name='id') }}", - description='', - columns={}, - column_name='id', - ), - ] - - self._expected_model_tests = [ - ParsedTestNode( - alias='accepted_values_model_one_id__a__b', - name='accepted_values_model_one_id__a__b', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.root.accepted_values_model_one_id__a__b', - fqn=['root', 'schema_test', - 'accepted_values_model_one_id__a__b'], - package_name='root', - original_file_path='test_one.yml', - root_path=get_os_path('/usr/src/app'), - refs=[['model_one']], - sources=[], - depends_on=DependsOn(), - config=self.test_config, - path=get_os_path( - 'schema_test/accepted_values_model_one_id__a__b.sql'), - tags=['schema'], - raw_sql="{{ config(severity='ERROR') }}{{ test_accepted_values(model=ref('model_one'), column_name='id', values=['a', 'b']) }}", - description='', - columns={}, - column_name='id', - ), - ParsedTestNode( - alias='not_null_model_one_id', - name='not_null_model_one_id', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.root.not_null_model_one_id', - fqn=['root', 'schema_test', 'not_null_model_one_id'], - package_name='root', - root_path=get_os_path('/usr/src/app'), - 
refs=[['model_one']],
-                sources=[],
-                depends_on=DependsOn(),
-                config=self.test_config,
-                original_file_path='test_one.yml',
-                path=get_os_path('schema_test/not_null_model_one_id.sql'),
-                tags=['schema'],
-                raw_sql="{{ config(severity='ERROR') }}{{ test_not_null(model=ref('model_one'), column_name='id') }}",
-                description='',
-                columns={},
-                column_name='id',
-            ),
-            ParsedTestNode(
-                alias='relationships_model_one_id__id__ref_model_two_',
-                name='relationships_model_one_id__id__ref_model_two_',
-                database='test',
-                schema='analytics',
-                resource_type=NodeType.Test,
-                unique_id='test.root.relationships_model_one_id__id__ref_model_two_',  # noqa
-                fqn=['root', 'schema_test',
-                     'relationships_model_one_id__id__ref_model_two_'],
-                package_name='root',
-                original_file_path='test_one.yml',
-                root_path=get_os_path('/usr/src/app'),
-                refs=[['model_one'], ['model_two']],
-                sources=[],
-                depends_on=DependsOn(),
-                config=self.test_config,
-                path=get_os_path('schema_test/relationships_model_one_id__id__ref_model_two_.sql'),  # noqa
-                tags=['schema'],
-                raw_sql="{{ config(severity='ERROR') }}{{ test_relationships(model=ref('model_one'), column_name='id', from='id', to=ref('model_two')) }}",
-                description='',
-                columns={},
-                column_name='id',
+            schema='analytics',
+            resource_type=NodeType.Model,
+            unique_id='model.snowplow.model_1',
+            fqn=['snowplow', 'nested', 'model_1'],
+            package_name='snowplow',
+            original_file_path=normalize('models/nested/model_1.sql'),
+            root_path=get_abs_os_path('./dbt_modules/snowplow'),
+            config=NodeConfig(materialized='table'),
+            path=normalize('nested/model_1.sql'),
+            raw_sql=raw_sql,
+        )
+        self.assertEqual(node, expected)
+        path = os.path.abspath('./dbt_modules/snowplow/models/nested/model_1.sql')
+        self.assertIn(path, self.parser.results.files)
+        self.assertEqual(self.parser.results.files[path].nodes, ['model.snowplow.model_1'])
+
+    def test_parse_error(self):
+        block = self.file_block_for('{{ SYNTAX ERROR }}', 'nested/model_1.sql')
+        with self.assertRaises(CompilationException):
+            self.parser.parse_file(block)
+        self.assert_has_results_length(self.parser.results, files=0)
+
+
+class SnapshotParserTest(BaseParserTest):
+    def setUp(self):
+        super().setUp()
+        self.parser = SnapshotParser(
+            results=ParseResult.rpc(),
+            project=self.snowplow_project_config,
+            root_project=self.root_project_config,
+            macro_manifest=self.macro_manifest,
+        )
+
+    def file_block_for(self, data, filename):
+        return super().file_block_for(data, filename, 'snapshots')
+
+    def test_parse_error(self):
+        block = self.file_block_for('{% snapshot foo %}select 1 as id{%snapshot bar %}{% endsnapshot %}', 'nested/snap_1.sql')
+        with self.assertRaises(CompilationException):
+            self.parser.parse_file(block)
+        self.assert_has_results_length(self.parser.results, files=0)
+
+    def test_single_block(self):
+        raw_sql = '''{{
+            config(unique_key="id", target_schema="analytics",
+                   target_database="dbt", strategy="timestamp",
+                   updated_at="last_update")
+        }}
+        select 1 as id, now() as last_update'''
+        full_file = '''
+        {{% snapshot foo %}}{}{{% endsnapshot %}}
+        '''.format(raw_sql)
+        block = self.file_block_for(full_file, 'nested/snap_1.sql')
+        self.parser.parse_file(block)
+        self.assert_has_results_length(self.parser.results, nodes=1)
+        node = list(self.parser.results.nodes.values())[0]
+        expected = ParsedSnapshotNode(
+            alias='foo',
+            name='foo',
+            # the `database` entry is overridden by the target_database config
+            database='dbt',
+            schema='analytics',
+            resource_type=SnapshotType.Snapshot,
+            
unique_id='snapshot.snowplow.foo', + fqn=['snowplow', 'nested', 'snap_1', 'foo'], + package_name='snowplow', + original_file_path=normalize('snapshots/nested/snap_1.sql'), + root_path=get_abs_os_path('./dbt_modules/snowplow'), + config=TimestampSnapshotConfig( + strategy=TimestampStrategy.Timestamp, + updated_at='last_update', + target_database='dbt', + target_schema='analytics', + unique_key='id', + materialized='snapshot', ), - ParsedTestNode( - alias='some_test_model_one_value', - name='some_test_model_one_value', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.snowplow.some_test_model_one_value', - fqn=['snowplow', 'schema_test', 'some_test_model_one_value'], - package_name='snowplow', - original_file_path='test_one.yml', - root_path=get_os_path('/usr/src/app'), - refs=[['model_one']], - sources=[], - depends_on=DependsOn(), - config=self.warn_test_config, - path=get_os_path('schema_test/some_test_model_one_value.sql'), - tags=['schema'], - raw_sql="{{ config(severity='WARN') }}{{ snowplow.test_some_test(model=ref('model_one'), key='value') }}", - description='', - columns={}, + path=normalize('nested/snap_1.sql'), + raw_sql=raw_sql, + ) + self.assertEqual(node, expected) + path = os.path.abspath('./dbt_modules/snowplow/snapshots/nested/snap_1.sql') + self.assertIn(path, self.parser.results.files) + self.assertEqual(self.parser.results.files[path].nodes, ['snapshot.snowplow.foo']) + + def test_multi_block(self): + raw_1 = ''' + {{ + config(unique_key="id", target_schema="analytics", + target_database="dbt", strategy="timestamp", + updated_at="last_update") + }} + select 1 as id, now() as last_update + ''' + raw_2 = ''' + {{ + config(unique_key="id", target_schema="analytics", + target_database="dbt", strategy="timestamp", + updated_at="last_update") + }} + select 2 as id, now() as last_update + ''' + full_file = ''' + {{% snapshot foo %}}{}{{% endsnapshot %}} + {{% snapshot bar %}}{}{{% endsnapshot %}} + '''.format(raw_1, raw_2) + block = self.file_block_for(full_file, 'nested/snap_1.sql') + self.parser.parse_file(block) + self.assert_has_results_length(self.parser.results, nodes=2) + nodes = sorted(self.parser.results.nodes.values(), key=lambda n: n.name) + expect_foo = ParsedSnapshotNode( + alias='foo', + name='foo', + database='dbt', + schema='analytics', + resource_type=SnapshotType.Snapshot, + unique_id='snapshot.snowplow.foo', + fqn=['snowplow', 'nested', 'snap_1', 'foo'], + package_name='snowplow', + original_file_path=normalize('snapshots/nested/snap_1.sql'), + root_path=get_abs_os_path('./dbt_modules/snowplow'), + config=TimestampSnapshotConfig( + strategy=TimestampStrategy.Timestamp, + updated_at='last_update', + target_database='dbt', + target_schema='analytics', + unique_key='id', + materialized='snapshot', ), - ParsedTestNode( - alias='unique_model_one_id', - name='unique_model_one_id', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.root.unique_model_one_id', - fqn=['root', 'schema_test', 'unique_model_one_id'], - package_name='root', - root_path=get_os_path('/usr/src/app'), - refs=[['model_one']], - sources=[], - depends_on=DependsOn(), - config=self.warn_test_config, - original_file_path='test_one.yml', - path=get_os_path('schema_test/unique_model_one_id.sql'), - tags=['schema'], - raw_sql="{{ config(severity='WARN') }}{{ test_unique(model=ref('model_one'), column_name='id') }}", - description='', - columns={}, - column_name='id', + path=normalize('nested/snap_1.sql'), + raw_sql=raw_1, + 
) + expect_bar = ParsedSnapshotNode( + alias='bar', + name='bar', + database='dbt', + schema='analytics', + resource_type=SnapshotType.Snapshot, + unique_id='snapshot.snowplow.bar', + fqn=['snowplow', 'nested', 'snap_1', 'bar'], + package_name='snowplow', + original_file_path=normalize('snapshots/nested/snap_1.sql'), + root_path=get_abs_os_path('./dbt_modules/snowplow'), + config=TimestampSnapshotConfig( + strategy=TimestampStrategy.Timestamp, + updated_at='last_update', + target_database='dbt', + target_schema='analytics', + unique_key='id', + materialized='snapshot', ), - ] - - self._expected_patch = ParsedNodePatch( - name='model_one', - description='blah blah', - original_file_path='test_one.yml', - columns={ - 'id': ColumnInfo(name='id', description='user ID'), - }, - docrefs=[], + path=normalize('nested/snap_1.sql'), + raw_sql=raw_2, ) + self.assertEqual(nodes[0], expect_bar) + self.assertEqual(nodes[1], expect_foo) + path = os.path.abspath('./dbt_modules/snowplow/snapshots/nested/snap_1.sql') + self.assertIn(path, self.parser.results.files) + self.assertEqual(sorted(self.parser.results.files[path].nodes), + ['snapshot.snowplow.bar', 'snapshot.snowplow.foo']) - def test__source_schema(self): - test_yml = yaml.safe_load(''' - version: 2 - sources: - - name: my_source - loader: some_loader - description: my source description - quoting: - schema: True - identifier: True - freshness: - warn_after: - count: 10 - period: hour - error_after: - count: 20 - period: hour - loaded_at_field: something - schema: '{{ var("test_schema_name") }}' - tables: - - name: my_table - description: "my table description" - identifier: bar - freshness: - warn_after: - count: 7 - period: hour - quoting: - identifier: False - columns: - - name: id - description: user ID - tests: - - unique: - severity: WARN - - not_null - - accepted_values: - values: - - a - - b - - relationships: - from: id - to: ref('model_two') - tests: - - snowplow.some_test: - key: value - severity: WARN - ''') - parser = SchemaParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - root_dir = get_os_path('/usr/src/app') - results = list(parser.parse_schema( - path='test_one.yml', - test_yml=test_yml, - package_name='root', - root_dir=root_dir - )) - - tests = sorted((node for t, node in results if t == 'test'), - key=lambda n: n.name) - patches = sorted((node for t, node in results if t == 'patch'), - key=lambda n: n.name) - sources = sorted((node for t, node in results if t == 'source'), - key=lambda n: n.name) - self.assertEqual(len(tests), 5) - self.assertEqual(len(patches), 0) - self.assertEqual(len(sources), 1) - self.assertEqual(len(results), 6) - - for test, expected in zip(tests, self._expected_source_tests): - self.assertEqual(test, expected) - - self.assertEqual(sources[0], self._expected_source) - - def test__model_schema(self): - test_yml = yaml.safe_load(''' - version: 2 - models: - - name: model_one - description: blah blah - columns: - - name: id - description: user ID - tests: - - unique: - severity: WARN - - not_null - - accepted_values: - values: - - a - - b - - relationships: - from: id - to: ref('model_two') - tests: - - snowplow.some_test: - severity: WARN - key: value - ''') - parser = SchemaParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - results = list(parser.parse_schema( - path='test_one.yml', - test_yml=test_yml, - package_name='root', - root_dir=get_os_path('/usr/src/app') - )) - - tests = sorted((node for t, node in results if t == 'test'), - 
key=lambda n: n.name) - patches = sorted((node for t, node in results if t == 'patch'), - key=lambda n: n.name) - sources = sorted((node for t, node in results if t == 'source'), - key=lambda n: n.name) - self.assertEqual(len(tests), 5) - self.assertEqual(len(patches), 1) - self.assertEqual(len(sources), 0) - self.assertEqual(len(results), 6) - - for test, expected in zip(tests, self._expected_model_tests): - self.assertEqual(test, expected) - - - self.assertEqual(patches[0], self._expected_patch) - - def test__mixed_schema(self): - test_yml = yaml.safe_load(''' - version: 2 - quoting: - database: True - models: - - name: model_one - description: blah blah - columns: - - name: id - description: user ID - tests: - - unique: - severity: WARN - - not_null - - accepted_values: - values: - - a - - b - - relationships: - from: id - to: ref('model_two') - tests: - - snowplow.some_test: - severity: WARN - key: value - sources: - - name: my_source - loader: some_loader - description: my source description - quoting: - schema: True - identifier: True - freshness: - warn_after: - count: 10 - period: hour - error_after: - count: 20 - period: hour - loaded_at_field: something - schema: '{{ var("test_schema_name") }}' - tables: - - name: my_table - description: "my table description" - identifier: bar - freshness: - warn_after: - count: 7 - period: hour - quoting: - identifier: False - columns: - - name: id - description: user ID - tests: - - unique: - severity: WARN - - not_null - - accepted_values: - values: - - a - - b - - relationships: - from: id - to: ref('model_two') - tests: - - snowplow.some_test: - severity: WARN - key: value - ''') - parser = SchemaParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - results = list(parser.parse_schema( - path='test_one.yml', - test_yml=test_yml, - package_name='root', - root_dir=get_os_path('/usr/src/app') - )) - - tests = sorted((node for t, node in results if t == 'test'), - key=lambda n: n.name) - patches = sorted((node for t, node in results if t == 'patch'), - key=lambda n: n.name) - sources = sorted((node for t, node in results if t == 'source'), - key=lambda n: n.name) - self.assertEqual(len(tests), 10) - self.assertEqual(len(patches), 1) - self.assertEqual(len(sources), 1) - self.assertEqual(len(results), 12) - - expected_tests = self._expected_model_tests + self._expected_source_tests - expected_tests.sort(key=lambda n: n.name) - for test, expected in zip(tests, expected_tests): - self.assertEqual(test, expected) - - self.assertEqual(patches[0], self._expected_patch) - self.assertEqual(sources[0], self._expected_source) - - def test__source_schema_invalid_test_strict(self): - test_yml = yaml.safe_load(''' - version: 2 - sources: - - name: my_source - loader: some_loader - description: my source description - quoting: - schema: True - identifier: True - freshness: - warn_after: - count: 10 - period: hour - error_after: - count: 20 - period: hour - loaded_at_field: something - schema: foo - tables: - - name: my_table - description: "my table description" - identifier: bar - freshness: - warn_after: - count: 7 - period: hour - quoting: - identifier: False - columns: - - name: id - description: user ID - tests: - - unique: - severity: WARN - - not_null - - accepted_values: # this test is invalid - - values: - - a - - b - - relationships: - from: id - to: ref('model_two') - tests: - - snowplow.some_test: - severity: WARN - key: value - ''') - parser = SchemaParser( - self.root_project_config, - self.all_projects, - 
self.macro_manifest - ) - root_dir = get_os_path('/usr/src/app') - with self.assertRaises(dbt.exceptions.CompilationException): - list(parser.parse_schema( - path='test_one.yml', - test_yml=test_yml, - package_name='root', - root_dir=root_dir - )) - - def test__source_schema_invalid_test_not_strict(self): - dbt.flags.WARN_ERROR = False - dbt.flags.STRICT_MODE = False - test_yml = yaml.safe_load(''' - version: 2 - sources: - - name: my_source - loader: some_loader - description: my source description - quoting: - schema: True - identifier: True - freshness: - warn_after: - count: 10 - period: hour - error_after: - count: 20 - period: hour - loaded_at_field: something - schema: foo - tables: - - name: my_table - description: "my table description" - identifier: bar - freshness: - warn_after: - count: 7 - period: hour - quoting: - identifier: False - columns: - - name: id - description: user ID - tests: - - unique: - severity: WARN - - not_null - - accepted_values: # this test is invalid - - values: - - a - - b - - relationships: - from: id - to: ref('model_two') - tests: - - snowplow.some_test: - severity: WARN - key: value - ''') - parser = SchemaParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - root_dir = get_os_path('/usr/src/app') - results = list(parser.parse_schema( - path='test_one.yml', - test_yml=test_yml, - package_name='root', - root_dir=root_dir - )) - - tests = sorted((node for t, node in results if t == 'test'), - key=lambda n: n.name) - patches = sorted((node for t, node in results if t == 'patch'), - key=lambda n: n.name) - sources = sorted((node for t, node in results if t == 'source'), - key=lambda n: n.name) - self.assertEqual(len(tests), 4) - self.assertEqual(len(patches), 0) - self.assertEqual(len(sources), 1) - self.assertEqual(len(results), 5) - - expected_tests = [x for x in self._expected_source_tests - if 'accepted_values' not in x.unique_id] - for test, expected in zip(tests, expected_tests): - self.assertEqual(test, expected) - - self.assertEqual(sources[0], self._expected_source) - - @mock.patch.object(SchemaParser, 'find_schema_yml') - @mock.patch.object(dbt.parser.schemas, 'logger') - def test__schema_v2_as_v1(self, mock_logger, find_schema_yml): - test_yml = yaml.safe_load( - '{models: [{name: model_one, description: "blah blah", columns: [' - '{name: id, description: "user ID", tests: [unique, not_null, ' - '{accepted_values: {values: ["a", "b"]}},' - '{relationships: {from: id, to: ref(\'model_two\')}}]' - '}], tests: [some_test: { key: value }]}]}' - ) - find_schema_yml.return_value = [('/some/path/schema.yml', test_yml)] - root_project = {} - all_projects = {} - root_dir = '/some/path' - relative_dirs = ['a', 'b'] - parser = dbt.parser.schemas.SchemaParser(root_project, all_projects, None) - with self.assertRaises(dbt.exceptions.CompilationException) as cm: - parser.load_and_parse( - 'test', root_dir, relative_dirs - ) - self.assertIn('https://docs.getdbt.com/docs/schemayml-files', - str(cm.exception)) - - @mock.patch.object(SchemaParser, 'find_schema_yml') - @mock.patch.object(dbt.parser.schemas, 'logger') - def test__schema_v1_version_model(self, mock_logger, find_schema_yml): - test_yml = yaml.safe_load( - '{model_one: {constraints: {not_null: [id],' - 'unique: [id],' - 'accepted_values: [{field: id, values: ["a","b"]}],' - 'relationships: [{from: id, to: ref(\'model_two\'), field: id}]' # noqa - '}}, version: {constraints: {not_null: [id]}}}' - ) - find_schema_yml.return_value = [('/some/path/schema.yml', test_yml)] - 
root_project = {} - all_projects = {} - root_dir = '/some/path' - relative_dirs = ['a', 'b'] - parser = dbt.parser.schemas.SchemaParser(root_project, all_projects, None) - with self.assertRaises(dbt.exceptions.CompilationException) as cm: - parser.load_and_parse( - 'test', root_dir, relative_dirs - ) - self.assertIn('https://docs.getdbt.com/docs/schemayml-files', - str(cm.exception)) - - @mock.patch.object(SchemaParser, 'find_schema_yml') - @mock.patch.object(dbt.parser.schemas, 'logger') - def test__schema_v1_version_1(self, mock_logger, find_schema_yml): - test_yml = yaml.safe_load( - '{model_one: {constraints: {not_null: [id],' - 'unique: [id],' - 'accepted_values: [{field: id, values: ["a","b"]}],' - 'relationships: [{from: id, to: ref(\'model_two\'), field: id}]' # noqa - '}}, version: 1}' - ) - find_schema_yml.return_value = [('/some/path/schema.yml', test_yml)] - root_project = {} - all_projects = {} - root_dir = '/some/path' - relative_dirs = ['a', 'b'] - parser = dbt.parser.schemas.SchemaParser(root_project, all_projects, None) - with self.assertRaises(dbt.exceptions.CompilationException) as cm: - parser.load_and_parse( - 'test', root_dir, relative_dirs - ) - self.assertIn('https://docs.getdbt.com/docs/schemayml-files', - str(cm.exception)) - - -class ParserTest(BaseParserTest): - def _assert_parsed_sql_nodes(self, parse_result, parsed, disabled): - self.assertEqual(parse_result.parsed, parsed) - self.assertEqual(parse_result.disabled, disabled) - - - def find_input_by_name(self, models, name): - return next( - (model for model in models if model.get('name') == name), - {}) +class MacroParserTest(BaseParserTest): def setUp(self): super().setUp() - - self.macro_manifest = Manifest(macros={}, nodes={}, docs={}, - generated_at=datetime.utcnow(), disabled=[]) - - self.model_config = NodeConfig.from_dict({ - 'enabled': True, - 'materialized': 'view', - 'persist_docs': {}, - 'post-hook': [], - 'pre-hook': [], - 'vars': {}, - 'quoting': {}, - 'column_types': {}, - 'tags': [], - }) - - self.test_config = TestConfig.from_dict({ - 'enabled': True, - 'materialized': 'view', - 'persist_docs': {}, - 'post-hook': [], - 'pre-hook': [], - 'vars': {}, - 'quoting': {}, - 'column_types': {}, - 'tags': [], - 'severity': 'ERROR', - }) - - self.disabled_config = NodeConfig.from_dict({ - 'enabled': False, - 'materialized': 'view', - 'persist_docs': {}, - 'post-hook': [], - 'pre-hook': [], - 'vars': {}, - 'quoting': {}, - 'column_types': {}, - 'tags': [], - }) - - def test__single_model(self): - models = [{ - 'name': 'model_one', - 'resource_type': 'model', - 'package_name': 'root', - 'original_file_path': 'model_one.sql', - 'root_path': get_os_path('/usr/src/app'), - 'path': 'model_one.sql', - 'raw_sql': ("select * from events"), - }] - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.model_one': ParsedModelNode( - alias='model_one', - name='model_one', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.model_one', - fqn=['root', 'model_one'], - package_name='root', - original_file_path='model_one.sql', - root_path=get_os_path('/usr/src/app'), - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='model_one.sql', - raw_sql=self.find_input_by_name( - models, 'model_one').get('raw_sql'), - description='', - columns={} - ) - }, - [] - ) - - def 
test__single_model__nested_configuration(self): - models = [{ - 'name': 'model_one', - 'resource_type': 'model', - 'package_name': 'root', - 'original_file_path': 'nested/path/model_one.sql', - 'root_path': get_os_path('/usr/src/app'), - 'path': get_os_path('nested/path/model_one.sql'), - 'raw_sql': ("select * from events"), - }] - - self.root_project_config.models = { - 'materialized': 'ephemeral', - 'root': { - 'nested': { - 'path': { - 'materialized': 'ephemeral' - } - } - } - } - - ephemeral_config = self.model_config.replace(materialized='ephemeral') - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.model_one': ParsedModelNode( - alias='model_one', - name='model_one', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.model_one', - fqn=['root', 'nested', 'path', 'model_one'], - package_name='root', - original_file_path='nested/path/model_one.sql', - root_path=get_os_path('/usr/src/app'), - refs=[], - sources=[], - depends_on=DependsOn(), - config=ephemeral_config, - tags=[], - path=get_os_path('nested/path/model_one.sql'), - raw_sql=self.find_input_by_name( - models, 'model_one').get('raw_sql'), - description='', - columns={} - ) - }, - [] - ) - - def test__empty_model(self): - models = [{ - 'name': 'model_one', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'model_one.sql', - 'original_file_path': 'model_one.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': (" "), - }] - - del self.all_projects['snowplow'] - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.model_one': ParsedModelNode( - alias='model_one', - name='model_one', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.model_one', - fqn=['root', 'model_one'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='model_one.sql', - original_file_path='model_one.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'model_one').get('raw_sql'), - description='', - columns={} - ) - }, - [] - ) - - def test__simple_dependency(self): - models = [{ - 'name': 'base', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'base.sql', - 'original_file_path': 'base.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'select * from events' - }, { - 'name': 'events_tx', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'events_tx.sql', - 'original_file_path': 'events_tx.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': "select * from {{ref('base')}}" - }] - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.base': ParsedModelNode( - alias='base', - name='base', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.base', - fqn=['root', 'base'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='base.sql', - original_file_path='base.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name(models, 'base').get('raw_sql'), - description='', - 
columns={} - - ), - 'model.root.events_tx': ParsedModelNode( - alias='events_tx', - name='events_tx', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.events_tx', - fqn=['root', 'events_tx'], - package_name='root', - refs=[['base']], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='events_tx.sql', - original_file_path='events_tx.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name(models, 'events_tx').get('raw_sql'), - description='', - columns={} - ) - }, - [] - ) - - def test__multiple_dependencies(self): - models = [{ - 'name': 'events', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'events.sql', - 'original_file_path': 'events.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'select * from base.events', - }, { - 'name': 'sessions', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'sessions.sql', - 'original_file_path': 'sessions.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'select * from base.sessions', - }, { - 'name': 'events_tx', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'events_tx.sql', - 'original_file_path': 'events_tx.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("with events as (select * from {{ref('events')}}) " - "select * from events"), - }, { - 'name': 'sessions_tx', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'sessions_tx.sql', - 'original_file_path': 'sessions_tx.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("with sessions as (select * from {{ref('sessions')}}) " - "select * from sessions"), - }, { - 'name': 'multi', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'multi.sql', - 'original_file_path': 'multi.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("with s as (select * from {{ref('sessions_tx')}}), " - "e as (select * from {{ref('events_tx')}}) " - "select * from e left join s on s.id = e.sid"), - }] - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.events': ParsedModelNode( - alias='events', - name='events', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.events', - fqn=['root', 'events'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='events.sql', - original_file_path='events.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'events').get('raw_sql'), - description='', - columns={} - ), - 'model.root.sessions': ParsedModelNode( - alias='sessions', - name='sessions', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.sessions', - fqn=['root', 'sessions'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='sessions.sql', - original_file_path='sessions.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'sessions').get('raw_sql'), - description='', - columns={}, - ), - 'model.root.events_tx': ParsedModelNode( - alias='events_tx', - name='events_tx', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.events_tx', - fqn=['root', 'events_tx'], - package_name='root', - refs=[['events']], - 
sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='events_tx.sql', - original_file_path='events_tx.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'events_tx').get('raw_sql'), - description='', - columns={} - ), - 'model.root.sessions_tx': ParsedModelNode( - alias='sessions_tx', - name='sessions_tx', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.sessions_tx', - fqn=['root', 'sessions_tx'], - package_name='root', - refs=[['sessions']], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='sessions_tx.sql', - original_file_path='sessions_tx.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'sessions_tx').get('raw_sql'), - description='', - columns={} - ), - 'model.root.multi': ParsedModelNode( - alias='multi', - name='multi', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.multi', - fqn=['root', 'multi'], - package_name='root', - refs=[['sessions_tx'], ['events_tx']], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='multi.sql', - original_file_path='multi.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'multi').get('raw_sql'), - description='', - columns={} - ), - }, - [] - ) - - def test__multiple_dependencies__packages(self): - models = [{ - 'name': 'events', - 'resource_type': 'model', - 'package_name': 'snowplow', - 'path': 'events.sql', - 'original_file_path': 'events.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'select * from base.events', - }, { - 'name': 'sessions', - 'resource_type': 'model', - 'package_name': 'snowplow', - 'path': 'sessions.sql', - 'original_file_path': 'sessions.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'select * from base.sessions', - }, { - 'name': 'events_tx', - 'resource_type': 'model', - 'package_name': 'snowplow', - 'path': 'events_tx.sql', - 'original_file_path': 'events_tx.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("with events as (select * from {{ref('events')}}) " - "select * from events"), - }, { - 'name': 'sessions_tx', - 'resource_type': 'model', - 'package_name': 'snowplow', - 'path': 'sessions_tx.sql', - 'original_file_path': 'sessions_tx.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("with sessions as (select * from {{ref('sessions')}}) " - "select * from sessions"), - }, { - 'name': 'multi', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'multi.sql', - 'original_file_path': 'multi.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("with s as " - "(select * from {{ref('snowplow', 'sessions_tx')}}), " - "e as " - "(select * from {{ref('snowplow', 'events_tx')}}) " - "select * from e left join s on s.id = e.sid"), - }] - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.snowplow.events': ParsedModelNode( - alias='events', - name='events', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.snowplow.events', - fqn=['snowplow', 'events'], - package_name='snowplow', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='events.sql', - original_file_path='events.sql', - root_path=get_os_path('/usr/src/app'), - 
raw_sql=self.find_input_by_name( - models, 'events').get('raw_sql'), - description='', - columns={} - ), - 'model.snowplow.sessions': ParsedModelNode( - alias='sessions', - name='sessions', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.snowplow.sessions', - fqn=['snowplow', 'sessions'], - package_name='snowplow', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='sessions.sql', - original_file_path='sessions.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'sessions').get('raw_sql'), - description='', - columns={} - ), - 'model.snowplow.events_tx': ParsedModelNode( - alias='events_tx', - name='events_tx', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.snowplow.events_tx', - fqn=['snowplow', 'events_tx'], - package_name='snowplow', - refs=[['events']], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='events_tx.sql', - original_file_path='events_tx.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'events_tx').get('raw_sql'), - description='', - columns={} - ), - 'model.snowplow.sessions_tx': ParsedModelNode( - alias='sessions_tx', - name='sessions_tx', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.snowplow.sessions_tx', - fqn=['snowplow', 'sessions_tx'], - package_name='snowplow', - refs=[['sessions']], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='sessions_tx.sql', - original_file_path='sessions_tx.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'sessions_tx').get('raw_sql'), - description='', - columns={} - ), - 'model.root.multi': ParsedModelNode( - alias='multi', - name='multi', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.multi', - fqn=['root', 'multi'], - package_name='root', - refs=[['snowplow', 'sessions_tx'], - ['snowplow', 'events_tx']], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='multi.sql', - original_file_path='multi.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'multi').get('raw_sql'), - description='', - columns={} - ), - }, - [] + self.parser = MacroParser( + results=ParseResult.rpc(), + project=self.snowplow_project_config, + ) + + def file_block_for(self, data, filename): + return super().file_block_for(data, filename, 'macros') + + def test_single_block(self): + raw_sql = '{% macro foo(a, b) %}a ~ b{% endmacro %}' + block = self.file_block_for(raw_sql, 'macro.sql') + self.parser.parse_file(block) + self.assert_has_results_length(self.parser.results, macros=1) + macro = list(self.parser.results.macros.values())[0] + expected = ParsedMacro( + name='foo', + resource_type=MacroType.Macro, + unique_id='macro.snowplow.foo', + package_name='snowplow', + original_file_path=normalize('macros/macro.sql'), + root_path=get_abs_os_path('./dbt_modules/snowplow'), + path=normalize('macros/macro.sql'), + raw_sql=raw_sql + ) + self.assertEqual(macro, expected) + path = os.path.abspath('./dbt_modules/snowplow/macros/macro.sql') + self.assertIn(path, self.parser.results.files) + self.assertEqual(self.parser.results.files[path].macros, ['macro.snowplow.foo']) + + +class DataTestParserTest(BaseParserTest): + def setUp(self): + super().setUp() + self.parser = 
DataTestParser( + results=ParseResult.rpc(), + project=self.snowplow_project_config, + root_project=self.root_project_config, + macro_manifest=self.macro_manifest, + ) + + def file_block_for(self, data, filename): + return super().file_block_for(data, filename, 'tests') + + def test_basic(self): + raw_sql = 'select * from {{ ref("blah") }} limit 0' + block = self.file_block_for(raw_sql, 'test_1.sql') + self.parser.parse_file(block) + self.assert_has_results_length(self.parser.results, nodes=1) + node = list(self.parser.results.nodes.values())[0] + expected = ParsedTestNode( + alias='test_1', + name='test_1', + database='test', + schema='analytics', + resource_type=TestType.Test, + unique_id='test.snowplow.test_1', + fqn=['snowplow', 'data_test', 'test_1'], + package_name='snowplow', + original_file_path=normalize('tests/test_1.sql'), + root_path=get_abs_os_path('./dbt_modules/snowplow'), + refs=[['blah']], + config=TestConfig(severity='ERROR'), + tags=['data'], + path=normalize('data_test/test_1.sql'), + raw_sql=raw_sql, + ) + self.assertEqual(node, expected) + path = os.path.abspath('./dbt_modules/snowplow/tests/test_1.sql') + self.assertIn(path, self.parser.results.files) + self.assertEqual(self.parser.results.files[path].nodes, ['test.snowplow.test_1']) + + +class AnalysisParserTest(BaseParserTest): + def setUp(self): + super().setUp() + self.parser = AnalysisParser( + results=ParseResult.rpc(), + project=self.snowplow_project_config, + root_project=self.root_project_config, + macro_manifest=self.macro_manifest, + ) + + def file_block_for(self, data, filename): + return super().file_block_for(data, filename, 'analyses') + + def test_basic(self): + raw_sql = 'select 1 as id' + block = self.file_block_for(raw_sql, 'nested/analysis_1.sql') + self.parser.parse_file(block) + self.assert_has_results_length(self.parser.results, nodes=1) + node = list(self.parser.results.nodes.values())[0] + expected = ParsedAnalysisNode( + alias='analysis_1', + name='analysis_1', + database='test', + schema='analytics', + resource_type=AnalysisType.Analysis, + unique_id='analysis.snowplow.analysis_1', + fqn=['snowplow', 'analysis', 'nested', 'analysis_1'], + package_name='snowplow', + original_file_path=normalize('analyses/nested/analysis_1.sql'), + root_path=get_abs_os_path('./dbt_modules/snowplow'), + depends_on=DependsOn(), + config=NodeConfig(), + path=normalize('analysis/nested/analysis_1.sql'), + raw_sql=raw_sql, + ) + self.assertEqual(node, expected) + path = os.path.abspath('./dbt_modules/snowplow/analyses/nested/analysis_1.sql') + self.assertIn(path, self.parser.results.files) + self.assertEqual(self.parser.results.files[path].nodes, ['analysis.snowplow.analysis_1']) + + +class ParserUtilsTest(unittest.TestCase): + def setUp(self): + x_depends_on = mock.MagicMock() + y_depends_on = mock.MagicMock() + x_uid = 'model.project.x' + y_uid = 'model.otherproject.y' + src_uid = 'source.thirdproject.src.tbl' + remote_docref = mock.MagicMock(documentation_package='otherproject', documentation_name='my_doc', column_name=None) + docref = mock.MagicMock(documentation_package='', documentation_name='my_doc', column_name=None) + self.x_node = mock.MagicMock( + refs=[], sources=[['src', 'tbl']], docrefs=[remote_docref], unique_id=x_uid, + resource_type=NodeType.Model, depends_on=x_depends_on, + description='other_project: {{ doc("otherproject", "my_doc") }}', + ) + self.y_node = mock.MagicMock( + refs=[['x']], sources=[], docrefs=[docref], unique_id=y_uid, + resource_type=NodeType.Model, depends_on=y_depends_on, + 
description='{{ doc("my_doc") }}', + ) + self.src_node = mock.MagicMock( + resource_type=NodeType.Source, unique_id=src_uid, ) - - def test__process_refs__packages(self): nodes = { - 'model.snowplow.events': ParsedModelNode( - name='events', - alias='events', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.snowplow.events', - fqn=['snowplow', 'events'], - package_name='snowplow', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.disabled_config, - tags=[], - path='events.sql', - original_file_path='events.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql='does not matter', - ), - 'model.root.events': ParsedModelNode( - name='events', - alias='events', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.events', - fqn=['root', 'events'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='events.sql', - original_file_path='events.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql='does not matter', - ), - 'model.root.dep': ParsedModelNode( - name='dep', - alias='dep', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.dep', - fqn=['root', 'dep'], - package_name='root', - refs=[['events']], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='multi.sql', - original_file_path='multi.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql='does not matter', - ), + x_uid: self.x_node, + y_uid: self.y_node, + src_uid: self.src_node, } - - manifest = Manifest( - nodes=nodes, - macros={}, - docs={}, - generated_at=datetime.utcnow(), - disabled=[] - ) - - processed_manifest = ParserUtils.process_refs(manifest, 'root') - self.assertEqual( - processed_manifest.to_flat_graph()['nodes'], - { - 'model.snowplow.events': { - 'name': 'events', - 'alias': 'events', - 'database': 'test', - 'schema': 'analytics', - 'resource_type': 'model', - 'unique_id': 'model.snowplow.events', - 'fqn': ['snowplow', 'events'], - 'docrefs': [], - 'package_name': 'snowplow', - 'refs': [], - 'sources': [], - 'depends_on': { - 'nodes': [], - 'macros': [] - }, - 'config': self.disabled_config.to_dict(), - 'tags': [], - 'path': 'events.sql', - 'original_file_path': 'events.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'does not matter', - 'columns': {}, - 'description': '', - 'build_path': None, - 'patch_path': None, - }, - 'model.root.events': { - 'name': 'events', - 'alias': 'events', - 'database': 'test', - 'schema': 'analytics', - 'resource_type': 'model', - 'unique_id': 'model.root.events', - 'fqn': ['root', 'events'], - 'docrefs': [], - 'package_name': 'root', - 'refs': [], - 'sources': [], - 'depends_on': { - 'nodes': [], - 'macros': [] - }, - 'config': self.model_config.to_dict(), - 'tags': [], - 'path': 'events.sql', - 'original_file_path': 'events.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'does not matter', - 'columns': {}, - 'description': '', - 'build_path': None, - 'patch_path': None, - }, - 'model.root.dep': { - 'name': 'dep', - 'alias': 'dep', - 'database': 'test', - 'schema': 'analytics', - 'resource_type': 'model', - 'unique_id': 'model.root.dep', - 'fqn': ['root', 'dep'], - 'docrefs': [], - 'package_name': 'root', - 'refs': [['events']], - 'sources': [], - 'depends_on': { - 'nodes': ['model.root.events'], - 'macros': [] - }, - 'config': self.model_config.to_dict(), - 'tags': [], - 'path': 'multi.sql', - 
'original_file_path': 'multi.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': 'does not matter', - 'columns': {}, - 'description': '', - 'build_path': None, - 'patch_path': None, - } - } - ) - - def test__in_model_config(self): - models = [{ - 'name': 'model_one', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'model_one.sql', - 'original_file_path': 'model_one.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("{{config({'materialized':'table'})}}" - "select * from events"), - }] - - self.model_config = self.model_config.replace(materialized='table') - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.model_one': ParsedModelNode( - alias='model_one', - name='model_one', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.model_one', - fqn=['root', 'model_one'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - root_path=get_os_path('/usr/src/app'), - path='model_one.sql', - original_file_path='model_one.sql', - raw_sql=self.find_input_by_name( - models, 'model_one').get('raw_sql'), - description='', - columns={} - ) - }, - [] - ) - - def test__root_project_config(self): - self.root_project_config.models = { - 'materialized': 'ephemeral', - 'root': { - 'view': { - 'materialized': 'view' - } - } + docs = { + 'otherproject.my_doc': mock.MagicMock(block_contents='some docs') } - - models = [{ - 'name': 'table', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'table.sql', - 'original_file_path': 'table.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("{{config({'materialized':'table'})}}" - "select * from events"), - }, { - 'name': 'ephemeral', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'ephemeral.sql', - 'original_file_path': 'ephemeral.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("select * from events"), - }, { - 'name': 'view', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'view.sql', - 'original_file_path': 'view.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("select * from events"), - }] - - self.model_config = self.model_config.replace(materialized='table') - ephemeral_config = self.model_config.replace(materialized='ephemeral') - view_config = self.model_config.replace(materialized='view') - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.table': ParsedModelNode( - alias='table', - name='table', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.table', - fqn=['root', 'table'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - path='table.sql', - original_file_path='table.sql', - config=self.model_config, - tags=[], - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'table').get('raw_sql'), - description='', - columns={} - ), - 'model.root.ephemeral': ParsedModelNode( - alias='ephemeral', - name='ephemeral', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.ephemeral', - fqn=['root', 'ephemeral'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - path='ephemeral.sql', - 
original_file_path='ephemeral.sql', - config=ephemeral_config, - tags=[], - root_path=get_os_path('/usr/src/app'), - raw_sql=self.find_input_by_name( - models, 'ephemeral').get('raw_sql'), - description='', - columns={} - ), - 'model.root.view': ParsedModelNode( - alias='view', - name='view', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.view', - fqn=['root', 'view'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - path='view.sql', - original_file_path='view.sql', - root_path=get_os_path('/usr/src/app'), - config=view_config, - tags=[], - raw_sql=self.find_input_by_name( - models, 'ephemeral').get('raw_sql'), - description='', - columns={} - ), - }, - [] - ) - - def test__other_project_config(self): - self.root_project_config.models = { - 'materialized': 'ephemeral', - 'root': { - 'view': { - 'materialized': 'view' - } - }, - 'snowplow': { - 'enabled': False, - 'views': { - 'materialized': 'view', - 'multi_sort': { - 'enabled': True, - 'materialized': 'table' - } - } - } - } - - self.snowplow_project_config.models = { - 'snowplow': { - 'enabled': False, - 'views': { - 'materialized': 'table', - 'sort': 'timestamp', - 'multi_sort': { - 'sort': ['timestamp', 'id'], - } - } - } - } - - models = [{ - 'name': 'table', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'table.sql', - 'original_file_path': 'table.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("{{config({'materialized':'table'})}}" - "select * from events"), - }, { - 'name': 'ephemeral', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'ephemeral.sql', - 'original_file_path': 'ephemeral.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("select * from events"), - }, { - 'name': 'view', - 'resource_type': 'model', - 'package_name': 'root', - 'path': 'view.sql', - 'original_file_path': 'view.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("select * from events"), - }, { - 'name': 'disabled', - 'resource_type': 'model', - 'package_name': 'snowplow', - 'path': 'disabled.sql', - 'original_file_path': 'disabled.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("select * from events"), - }, { - 'name': 'package', - 'resource_type': 'model', - 'package_name': 'snowplow', - 'path': get_os_path('views/package.sql'), - 'original_file_path': get_os_path('views/package.sql'), - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("select * from events"), - }, { - 'name': 'multi_sort', - 'resource_type': 'model', - 'package_name': 'snowplow', - 'path': get_os_path('views/multi_sort.sql'), - 'original_file_path': get_os_path('views/multi_sort.sql'), - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': ("select * from events"), - }] - - self.model_config = self.model_config.replace(materialized='table') - - ephemeral_config = self.model_config.replace( - materialized='ephemeral' - ) - view_config = self.model_config.replace( - materialized='view' - ) - disabled_config = self.model_config.replace( - materialized='ephemeral', - enabled=False, - ) - sort_config = self.model_config.replace( - materialized='view', - enabled=False, - sort='timestamp', - ) - multi_sort_config = self.model_config.replace( - materialized='table', - sort=['timestamp', 'id'], - ) - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - parsed={ - 'model.root.table': ParsedModelNode( - 
alias='table', - name='table', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.table', - fqn=['root', 'table'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - path='table.sql', - original_file_path='table.sql', - root_path=get_os_path('/usr/src/app'), - config=self.model_config, - tags=[], - raw_sql=self.find_input_by_name( - models, 'table').get('raw_sql'), - description='', - columns={} - ), - 'model.root.ephemeral': ParsedModelNode( - alias='ephemeral', - name='ephemeral', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.ephemeral', - fqn=['root', 'ephemeral'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - path='ephemeral.sql', - original_file_path='ephemeral.sql', - root_path=get_os_path('/usr/src/app'), - config=ephemeral_config, - tags=[], - raw_sql=self.find_input_by_name( - models, 'ephemeral').get('raw_sql'), - description='', - columns={} - ), - 'model.root.view': ParsedModelNode( - alias='view', - name='view', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.view', - fqn=['root', 'view'], - package_name='root', - refs=[], - sources=[], - depends_on=DependsOn(), - path='view.sql', - original_file_path='view.sql', - root_path=get_os_path('/usr/src/app'), - config=view_config, - tags=[], - raw_sql=self.find_input_by_name( - models, 'view').get('raw_sql'), - description='', - columns={} - ), - 'model.snowplow.multi_sort': ParsedModelNode( - alias='multi_sort', - name='multi_sort', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.snowplow.multi_sort', - fqn=['snowplow', 'views', 'multi_sort'], - package_name='snowplow', - refs=[], - sources=[], - depends_on=DependsOn(), - path=get_os_path('views/multi_sort.sql'), - original_file_path=get_os_path('views/multi_sort.sql'), - root_path=get_os_path('/usr/src/app'), - config=multi_sort_config, - tags=[], - raw_sql=self.find_input_by_name( - models, 'multi_sort').get('raw_sql'), - description='', - columns={} - ), - }, - disabled=[ - ParsedModelNode( - name='disabled', - resource_type=NodeType.Model, - package_name='snowplow', - path='disabled.sql', - original_file_path='disabled.sql', - root_path=get_os_path('/usr/src/app'), - raw_sql=("select * from events"), - database='test', - schema='analytics', - refs=[], - sources=[], - depends_on=DependsOn(), - config=disabled_config, - tags=[], - alias='disabled', - unique_id='model.snowplow.disabled', - fqn=['snowplow', 'disabled'], - columns={} - ), - ParsedModelNode( - name='package', - resource_type=NodeType.Model, - package_name='snowplow', - path=get_os_path('views/package.sql'), - original_file_path=get_os_path('views/package.sql'), - root_path=get_os_path('/usr/src/app'), - raw_sql=("select * from events"), - database='test', - schema='analytics', - refs=[], - sources=[], - depends_on=DependsOn(), - config=sort_config, - tags=[], - alias='package', - unique_id='model.snowplow.package', - fqn=['snowplow', 'views', 'package'], - columns={} - ) - ] - ) - - def test__simple_data_test(self): - tests = [{ - 'name': 'no_events', - 'resource_type': 'test', - 'package_name': 'root', - 'path': 'no_events.sql', - 'original_file_path': 'no_events.sql', - 'root_path': get_os_path('/usr/src/app'), - 'raw_sql': "select * from {{ref('base')}}" - }] - - parser = DataTestParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - 
self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(tests), - { - 'test.root.no_events': ParsedTestNode( - alias='no_events', - name='no_events', - database='test', - schema='analytics', - resource_type=NodeType.Test, - unique_id='test.root.no_events', - fqn=['root', 'no_events'], - package_name='root', - refs=[['base']], - sources=[], - depends_on=DependsOn(), - config=self.test_config, - path='no_events.sql', - original_file_path='no_events.sql', - root_path=get_os_path('/usr/src/app'), - tags=[], - raw_sql=self.find_input_by_name( - tests, 'no_events').get('raw_sql'), - description='', - columns={} - ) - }, - [] - ) - - def test__simple_macro(self): - macro_file_contents = """ -{% macro simple(a, b) %} - {{a}} + {{b}} -{% endmacro %} -""" - parser = MacroParser(None, {}) - result = parser.parse_macro_file( - macro_file_path='simple_macro.sql', - macro_file_contents=macro_file_contents, - root_path=get_os_path('/usr/src/app'), - package_name='root', - resource_type=NodeType.Macro) - - self.assertTrue(callable(result['macro.root.simple'].generator)) - - self.assertEqual( - result, - { - 'macro.root.simple': ParsedMacro.from_dict({ - 'name': 'simple', - 'resource_type': 'macro', - 'unique_id': 'macro.root.simple', - 'package_name': 'root', - 'depends_on': { - 'macros': [] - }, - 'original_file_path': 'simple_macro.sql', - 'root_path': get_os_path('/usr/src/app'), - 'tags': [], - 'path': 'simple_macro.sql', - 'raw_sql': macro_file_contents, - }) - } - ) - - def test__simple_macro_used_in_model(self): - macro_file_contents = """ -{% macro simple(a, b) %} - {{a}} + {{b}} -{% endmacro %} -""" - parser = MacroParser(None, {}) - result = parser.parse_macro_file( - macro_file_path='simple_macro.sql', - macro_file_contents=macro_file_contents, - root_path=get_os_path('/usr/src/app'), - package_name='root', - resource_type=NodeType.Macro) - - self.assertTrue(callable(result['macro.root.simple'].generator)) - - self.assertEqual( - result, - { - 'macro.root.simple': ParsedMacro.from_dict({ - 'name': 'simple', - 'resource_type': 'macro', - 'unique_id': 'macro.root.simple', - 'package_name': 'root', - 'depends_on': { - 'macros': [] - }, - 'original_file_path': 'simple_macro.sql', - 'root_path': get_os_path('/usr/src/app'), - 'tags': [], - 'path': 'simple_macro.sql', - 'raw_sql': macro_file_contents, - }), - } - ) - - models = [{ - 'name': 'model_one', - 'resource_type': 'model', - 'package_name': 'root', - 'original_file_path': 'model_one.sql', - 'root_path': get_os_path('/usr/src/app'), - 'path': 'model_one.sql', - 'raw_sql': ("select *, {{package.simple(1, 2)}} from events"), - }] - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.model_one': ParsedModelNode( - alias='model_one', - name='model_one', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.model_one', - fqn=['root', 'model_one'], - package_name='root', - original_file_path='model_one.sql', - root_path=get_os_path('/usr/src/app'), - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='model_one.sql', - raw_sql=self.find_input_by_name( - models, 'model_one').get('raw_sql'), - description='', - columns={} - ) - }, - [] - ) - - def test__macro_no_explicit_project_used_in_model(self): - models = [{ - 'name': 'model_one', - 'resource_type': 'model', - 'package_name': 'root', - 'root_path': get_os_path('/usr/src/app'), - 
'path': 'model_one.sql', - 'original_file_path': 'model_one.sql', - 'raw_sql': ("select *, {{ simple(1, 2) }} from events"), - }] - - parser = ModelParser( - self.root_project_config, - self.all_projects, - self.macro_manifest - ) - - self._assert_parsed_sql_nodes( - parser.parse_sql_nodes(models), - { - 'model.root.model_one': ParsedModelNode( - alias='model_one', - name='model_one', - database='test', - schema='analytics', - resource_type=NodeType.Model, - unique_id='model.root.model_one', - fqn=['root', 'model_one'], - package_name='root', - root_path=get_os_path('/usr/src/app'), - refs=[], - sources=[], - depends_on=DependsOn(), - config=self.model_config, - tags=[], - path='model_one.sql', - original_file_path='model_one.sql', - raw_sql=self.find_input_by_name( - models, 'model_one').get('raw_sql'), - description='', - columns={} - ) - }, - [] - ) + self.manifest = Manifest( + nodes=nodes, macros={}, docs=docs, disabled=[], files={}, generated_at=mock.MagicMock() + ) + + def test_resolve_docs(self): + # no error. TODO: real test + result = ParserUtils.process_docs(self.manifest, 'project') + self.assertIs(result, self.manifest) + self.assertEqual(self.x_node.description, 'other_project: some docs') + self.assertEqual(self.y_node.description, 'some docs') + + def test_resolve_sources(self): + result = ParserUtils.process_sources(self.manifest, 'project') + self.assertIs(result, self.manifest) + self.x_node.depends_on.nodes.append.assert_called_once_with('source.thirdproject.src.tbl') + + def test_resolve_refs(self): + result = ParserUtils.process_refs(self.manifest, 'project') + self.assertIs(result, self.manifest) + self.y_node.depends_on.nodes.append.assert_called_once_with('model.project.x') diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index be8c75fb5cc..0cc97923270 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -6,6 +6,7 @@ from dbt.adapters.postgres import PostgresAdapter from dbt.exceptions import ValidationException from dbt.logger import GLOBAL_LOGGER as logger # noqa +from dbt.parser.results import ParseResult from psycopg2 import extensions as psycopg2_extensions from psycopg2 import DatabaseError, Error import agate @@ -232,10 +233,15 @@ def setUp(self): self.adapter.acquire_connection() inject_adapter(self.adapter) + self.load_patch = mock.patch('dbt.loader.make_parse_result') + self.mock_parse_result = self.load_patch.start() + self.mock_parse_result.return_value = ParseResult.rpc() + def tearDown(self): # we want a unique self.handle every time. 
self.adapter.cleanup_connections() self.patcher.stop() + self.load_patch.stop() def test_quoting_on_drop_schema(self): self.adapter.drop_schema(database='postgres', schema='test_schema') diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index 79a06320b95..db72d67bbe6 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -19,6 +19,7 @@ def fetch_cluster_credentials(*args, **kwargs): 'DbPassword': 'tmp_password' } + class TestRedshiftAdapter(unittest.TestCase): def setUp(self): diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 3fea43a2f8f..73a5e4edaa1 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -5,6 +5,7 @@ from dbt.adapters.snowflake import SnowflakeAdapter from dbt.logger import GLOBAL_LOGGER as logger # noqa +from dbt.parser.results import ParseResult from snowflake import connector as snowflake_connector from .utils import config_from_parts_or_dicts, inject_adapter, mock_connection @@ -48,15 +49,21 @@ def setUp(self): 'dbt.adapters.snowflake.connections.snowflake.connector.connect') self.snowflake = self.patcher.start() + self.load_patch = mock.patch('dbt.loader.make_parse_result') + self.mock_parse_result = self.load_patch.start() + self.mock_parse_result.return_value = ParseResult.rpc() + self.snowflake.return_value = self.handle self.adapter = SnowflakeAdapter(self.config) self.adapter.acquire_connection() inject_adapter(self.adapter) + def tearDown(self): # we want a unique self.handle every time. self.adapter.cleanup_connections() self.patcher.stop() + self.load_patch.stop() def test_quoting_on_drop_schema(self): self.adapter.drop_schema( diff --git a/test/unit/test_source_config.py b/test/unit/test_source_config.py new file mode 100644 index 00000000000..82e0547ebcd --- /dev/null +++ b/test/unit/test_source_config.py @@ -0,0 +1,139 @@ +import os +from unittest import TestCase, mock + +import dbt.flags +from dbt.node_types import NodeType +from dbt.source_config import SourceConfig + +from .utils import config_from_parts_or_dicts + + +class SourceConfigTest(TestCase): + def setUp(self): + dbt.flags.STRICT_MODE = True + dbt.flags.WARN_ERROR = True + + self.maxDiff = None + + profile_data = { + 'target': 'test', + 'quoting': {}, + 'outputs': { + 'test': { + 'type': 'redshift', + 'host': 'localhost', + 'schema': 'analytics', + 'user': 'test', + 'pass': 'test', + 'dbname': 'test', + 'port': 1, + } + } + } + + root_project = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + } + + self.root_project_config = config_from_parts_or_dicts( + project=root_project, + profile=profile_data, + cli_vars='{"test_schema_name": "foo"}' + ) + + snowplow_project = { + 'name': 'snowplow', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('./dbt_modules/snowplow'), + } + + self.snowplow_project_config = config_from_parts_or_dicts( + project=snowplow_project, profile=profile_data + ) + + self.all_projects = { + 'root': self.root_project_config, + 'snowplow': self.snowplow_project_config + } + self.patcher = mock.patch('dbt.context.parser.get_adapter') + self.factory = self.patcher.start() + + def tearDown(self): + self.patcher.stop() + + def test__source_config_single_call(self): + cfg = SourceConfig(self.root_project_config, self.root_project_config, + ['root', 'x'], NodeType.Model) + cfg.update_in_model_config({ + 'materialized': 'something', + 'sort': 'my sort 
key', + 'pre-hook': 'my pre run hook', + 'vars': {'a': 1, 'b': 2}, + }) + expect = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'something', + 'post-hook': [], + 'pre-hook': ['my pre run hook'], + 'persist_docs': {}, + 'quoting': {}, + 'sort': 'my sort key', + 'tags': [], + 'vars': {'a': 1, 'b': 2}, + } + self.assertEqual(cfg.config, expect) + + def test__source_config_multiple_calls(self): + cfg = SourceConfig(self.root_project_config, self.root_project_config, + ['root', 'x'], NodeType.Model) + cfg.update_in_model_config({ + 'materialized': 'something', + 'sort': 'my sort key', + 'pre-hook': 'my pre run hook', + 'vars': {'a': 1, 'b': 2}, + }) + cfg.update_in_model_config({ + 'materialized': 'something else', + 'pre-hook': ['my other pre run hook', 'another pre run hook'], + 'vars': {'a': 4, 'c': 3}, + }) + expect = { + 'column_types': {}, + 'enabled': True, + 'materialized': 'something else', + 'persist_docs': {}, + 'post-hook': [], + 'pre-hook': [ + 'my pre run hook', + 'my other pre run hook', + 'another pre run hook', + ], + 'quoting': {}, + 'sort': 'my sort key', + 'tags': [], + 'vars': {'a': 4, 'b': 2, 'c': 3}, + } + self.assertEqual(cfg.config, expect) + + def test_source_config_all_keys_accounted_for(self): + used_keys = frozenset(SourceConfig.AppendListFields) | \ + frozenset(SourceConfig.ExtendDictFields) | \ + frozenset(SourceConfig.ClobberFields) + + self.assertEqual(used_keys, frozenset(SourceConfig.ConfigKeys)) + + def test__source_config_wrong_type(self): + # ExtendDict fields must be dicts; a non-dict value should raise a CompilationException + self.root_project_config.models = {'persist_docs': False} + cfg = SourceConfig(self.root_project_config, self.root_project_config, + ['root', 'x'], NodeType.Model) + + with self.assertRaises(dbt.exceptions.CompilationException) as exc: + cfg.get_project_config(self.root_project_config) + + self.assertIn('must be a dict', str(exc.exception)) diff --git a/test/unit/utils.py b/test/unit/utils.py index ffc0529f790..51d7cf45d8e 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -3,12 +3,26 @@ Note that all imports should be inside the functions to avoid import/mocking issues. """ +import os from unittest import mock from unittest import TestCase from hologram import ValidationError + +def normalize(path): + """On windows, neither is enough on its own: + + >>> normcase('C:\\documents/ALL CAPS/subdir\\..') + 'c:\\documents\\all caps\\subdir\\..' + >>> normpath('C:\\documents/ALL CAPS/subdir\\..') + 'C:\\documents\\ALL CAPS' + >>> normpath(normcase('C:\\documents/ALL CAPS/subdir\\..')) + 'c:\\documents\\all caps' + """ + return os.path.normcase(os.path.normpath(path)) + + class Obj: which = 'blah' @@ -33,6 +47,7 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): cli_vars) args = Obj() args.vars = repr(cli_vars) + args.profile_dir = '/dev/null' return RuntimeConfig.from_parts( project=project, profile=profile, diff --git a/third-party-stubs/agate/__init__.pyi b/third-party-stubs/agate/__init__.pyi new file mode 100644 index 00000000000..98b35c03216 --- /dev/null +++ b/third-party-stubs/agate/__init__.pyi @@ -0,0 +1,27 @@ +from typing import Any, Optional + +from . import data_types + + +class Table: + def __init__(self, rows: Any, column_names: Optional[Any] = ..., column_types: Optional[Any] = ..., row_names: Optional[Any] = ..., _is_fork: bool = ...) -> None: ... + def __len__(self): ... + def __iter__(self): ... + def __getitem__(self, key: Any): ... + @property + def column_types(self): ... 
+ @property + def column_names(self): ... + @property + def row_names(self): ... + @property + def columns(self): ... + @property + def rows(self): ... + def print_csv(self, **kwargs: Any) -> None: ... + def print_json(self, **kwargs: Any) -> None: ... + + +class TypeTester: + def __init__(self, force: Any = ..., limit: Optional[Any] = ..., types: Optional[Any] = ...) -> None: ... + def run(self, rows: Any, column_names: Any): ... diff --git a/third-party-stubs/agate/data_types.pyi b/third-party-stubs/agate/data_types.pyi new file mode 100644 index 00000000000..e4243732fb1 --- /dev/null +++ b/third-party-stubs/agate/data_types.pyi @@ -0,0 +1,72 @@ + +from typing import Any, Optional + +DEFAULT_NULL_VALUES: Any + + +class DataType: + null_values: Any = ... + def __init__(self, null_values: Any = ...) -> None: ... + def test(self, d: Any): ... + def cast(self, d: Any) -> None: ... + def csvify(self, d: Any): ... + def jsonify(self, d: Any): ... + + +DEFAULT_TRUE_VALUES: Any +DEFAULT_FALSE_VALUES: Any + + +class Boolean(DataType): + true_values: Any = ... + false_values: Any = ... + def __init__(self, true_values: Any = ..., false_values: Any = ..., null_values: Any = ...) -> None: ... + def cast(self, d: Any): ... + def jsonify(self, d: Any): ... + + +ZERO_DT: Any + + +class Date(DataType): + date_format: Any = ... + parser: Any = ... + def __init__(self, date_format: Optional[Any] = ..., **kwargs: Any) -> None: ... + def cast(self, d: Any): ... + def csvify(self, d: Any): ... + def jsonify(self, d: Any): ... + + +class DateTime(DataType): + datetime_format: Any = ... + timezone: Any = ... + def __init__(self, datetime_format: Optional[Any] = ..., timezone: Optional[Any] = ..., **kwargs: Any) -> None: ... + def cast(self, d: Any): ... + def csvify(self, d: Any): ... + def jsonify(self, d: Any): ... + + +DEFAULT_CURRENCY_SYMBOLS: Any +POSITIVE: Any +NEGATIVE: Any + + +class Number(DataType): + locale: Any = ... + currency_symbols: Any = ... + group_symbol: Any = ... + decimal_symbol: Any = ... + def __init__(self, locale: str = ..., group_symbol: Optional[Any] = ..., decimal_symbol: Optional[Any] = ..., currency_symbols: Any = ..., **kwargs: Any) -> None: ... + def cast(self, d: Any): ... + def jsonify(self, d: Any): ... + + +class TimeDelta(DataType): + def cast(self, d: Any): ... + + +class Text(DataType): + cast_nulls: Any = ... + def __init__(self, cast_nulls: bool = ..., **kwargs: Any) -> None: ... + def cast(self, d: Any): ... + diff --git a/third-party-stubs/cdecimal/__init__.pyi b/third-party-stubs/cdecimal/__init__.pyi new file mode 100644 index 00000000000..d21582ce67c --- /dev/null +++ b/third-party-stubs/cdecimal/__init__.pyi @@ -0,0 +1,2 @@ +class Decimal: + pass \ No newline at end of file diff --git a/third-party-stubs/colorama/__init__.pyi b/third-party-stubs/colorama/__init__.pyi new file mode 100644 index 00000000000..693e7f9891d --- /dev/null +++ b/third-party-stubs/colorama/__init__.pyi @@ -0,0 +1,12 @@ +from typing import Optional, Any + +class Fore: + RED: str = ... + GREEN: str = ... + YELLOW: str = ... + +class Style: + RESET_ALL: str = ... + + +def init(autoreset: bool = ..., convert: Optional[Any] = ..., strip: Optional[Any] = ..., wrap: bool = ...) -> None: ... 
\ No newline at end of file diff --git a/third-party-stubs/snowplow_tracker/__init__.pyi b/third-party-stubs/snowplow_tracker/__init__.pyi new file mode 100644 index 00000000000..6c88a057f94 --- /dev/null +++ b/third-party-stubs/snowplow_tracker/__init__.pyi @@ -0,0 +1,53 @@ +import logging +from typing import Union, Optional, List, Any, Dict + +class Subject: + def __init__(self) -> None: ... + def set_platform(self, value: Any): ... + def set_user_id(self, user_id: Any): ... + def set_screen_resolution(self, width: Any, height: Any): ... + def set_viewport(self, width: Any, height: Any): ... + def set_color_depth(self, depth: Any): ... + def set_timezone(self, timezone: Any): ... + def set_lang(self, lang: Any): ... + def set_domain_user_id(self, duid: Any): ... + def set_ip_address(self, ip: Any): ... + def set_useragent(self, ua: Any): ... + def set_network_user_id(self, nuid: Any): ... + + +logger: logging.Logger + + +class Emitter: + def __init__(self, endpoint: str, protocol: str = ..., port: Optional[int] = ..., method: str = ..., buffer_size: Optional[int] = ..., on_success: Optional[Any] = ..., on_failure: Optional[Any] = ..., byte_limit: Optional[int] = ...) -> None: ... + + +class Tracker: + emitters: Union[List[Any], Any] = ... + subject: Optional[Subject] = ... + namespace: Optional[str] = ... + app_id: Optional[str] = ... + encode_base64: bool = ... + + def __init__(self, emitters: Union[List[Any], Any], subject: Optional[Subject] = ..., namespace: Optional[str] = ..., app_id: Optional[str] = ..., encode_base64: bool = ...) -> None: ... + # @staticmethod + # def get_uuid(): ... + # @staticmethod + # def get_timestamp(tstamp: Optional[Any] = ...): ... + # def track(self, pb: Any): ... + # def complete_payload(self, pb: Any, context: Any, tstamp: Any): ... + # def track_struct_event(self, category: Any, action: Any, label: Optional[Any] = ..., property_: Optional[Any] = ..., value: Optional[Any] = ..., context: Optional[Any] = ..., tstamp: Optional[Any] = ...): ... + # def track_unstruct_event(self, event_json: Any, context: Optional[Any] = ..., tstamp: Optional[Any] = ...): ... + # track_self_describing_event: Any = ... + # def flush(self, asynchronous: bool = ...): ... + # def set_subject(self, subject: Any): ... + # def add_emitter(self, emitter: Any): ... + + +class SelfDescribingJson: + schema: Any = ... + data: Any = ... + def __init__(self, schema: Any, data: Any) -> None: ... + def to_json(self) -> Dict[str, Any]: ... + def to_string(self) -> str: ... 
diff --git a/tox.ini b/tox.ini index 78f78319636..05e0881cad7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] skipsdist = True -envlist = unit-py36, integration-postgres-py36, integration-redshift-py36, integration-snowflake-py36, flake8, integration-bigquery-py36 +envlist = unit-py36, integration-postgres-py36, integration-redshift-py36, integration-snowflake-py36, flake8, integration-bigquery-py36, mypy [testenv:flake8] @@ -9,6 +9,51 @@ commands = /bin/bash -c '$(which flake8) --select=E,W,F --ignore=W504 core/dbt p deps = -r{toxinidir}/dev_requirements.txt +[testenv:mypy] +basepython = python3.6 +commands = /bin/bash -c '$(which mypy) --namespace-packages \ + core/dbt/clients \ + core/dbt/config \ + core/dbt/exceptions.py \ + core/dbt/flags.py \ + core/dbt/helper_types.py \ + core/dbt/hooks.py \ + core/dbt/include \ + core/dbt/links.py \ + core/dbt/loader.py \ + core/dbt/logger.py \ + core/dbt/main.py \ + core/dbt/node_runners.py \ + core/dbt/node_types.py \ + core/dbt/parser \ + core/dbt/profiler.py \ + core/dbt/py.typed \ + core/dbt/semver.py \ + core/dbt/source_config.py \ + core/dbt/task/base.py \ + core/dbt/task/clean.py \ + core/dbt/task/debug.py \ + core/dbt/task/freshness.py \ + core/dbt/task/generate.py \ + core/dbt/task/init.py \ + core/dbt/task/list.py \ + core/dbt/task/run_operation.py \ + core/dbt/task/runnable.py \ + core/dbt/task/seed.py \ + core/dbt/task/serve.py \ + core/dbt/task/snapshot.py \ + core/dbt/task/test.py \ + core/dbt/tracking.py \ + core/dbt/ui \ + core/dbt/utils.py \ + core/dbt/version.py \ + core/dbt/writer.py' +setenv = + MYPYPATH={toxinidir}/third-party-stubs +deps = + -r{toxinidir}/requirements.txt + -r{toxinidir}/dev_requirements.txt + [testenv:unit-py36] basepython = python3.6 commands = /bin/bash -c '{envpython} -m pytest --durations 0 -v {posargs} -n4 test/unit'