diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index a90e5986d2e..aa7165a2cc2 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -1,8 +1,13 @@ +import enum +import os +import uuid from dataclasses import dataclass, field +from datetime import datetime, timedelta from numbers import Real -from typing import Optional, Union, List, Any, Dict +from typing import Optional, Union, List, Any, Dict, Type from hologram import JsonSchemaMixin +from hologram.helpers import StrEnum from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.results import ( @@ -10,7 +15,12 @@ CatalogResults, ExecutionResult, ) +from dbt.exceptions import InternalException from dbt.logger import LogMessage +from dbt.utils import restrict_to + + +TaskTags = Optional[Dict[str, Any]] # Inputs @@ -18,7 +28,7 @@ @dataclass class RPCParameters(JsonSchemaMixin): timeout: Optional[Real] - task_tags: Optional[Dict[str, Any]] + task_tags: TaskTags @dataclass @@ -55,21 +65,34 @@ class RPCCliParameters(RPCParameters): cli: str +@dataclass +class RPCNoParameters(RPCParameters): + pass + + # Outputs +@dataclass +class RemoteResult(JsonSchemaMixin): + logs: List[LogMessage] + @dataclass -class RemoteCatalogResults(CatalogResults): - logs: List[LogMessage] = field(default_factory=list) +class RemoteEmptyResult(RemoteResult): + pass + + +@dataclass +class RemoteCatalogResults(CatalogResults, RemoteResult): + pass @dataclass -class RemoteCompileResult(JsonSchemaMixin): +class RemoteCompileResult(RemoteResult): raw_sql: str compiled_sql: str node: CompileResultNode timing: List[TimingInfo] - logs: List[LogMessage] @property def error(self): @@ -77,8 +100,8 @@ def error(self): @dataclass -class RemoteExecutionResult(ExecutionResult): - logs: List[LogMessage] +class RemoteExecutionResult(ExecutionResult, RemoteResult): + pass @dataclass @@ -90,3 +113,312 @@ class ResultTable(JsonSchemaMixin): @dataclass class RemoteRunResult(RemoteCompileResult): table: ResultTable + + +RPCResult = Union[ + RemoteCompileResult, + RemoteExecutionResult, + RemoteCatalogResults, + RemoteEmptyResult, +] + + +# GC types + + +class GCResultState(StrEnum): + Deleted = 'deleted' # successful GC + Missing = 'missing' # nothing to GC + Running = 'running' # can't GC + + +@dataclass +class GCResultSet(JsonSchemaMixin): + deleted: List[uuid.UUID] = field(default_factory=list) + missing: List[uuid.UUID] = field(default_factory=list) + running: List[uuid.UUID] = field(default_factory=list) + + def add_result(self, task_id: uuid.UUID, status: GCResultState): + if status == GCResultState.Missing: + self.missing.append(task_id) + elif status == GCResultState.Running: + self.running.append(task_id) + elif status == GCResultState.Deleted: + self.deleted.append(task_id) + else: + raise InternalException( + f'Got invalid status in add_result: {status}' + ) + + +@dataclass +class GCSettings(JsonSchemaMixin): + # start evicting the longest-ago-ended tasks here + maxsize: int + # start evicting all tasks before now - auto_reap_age when we have this + # many tasks in the table + reapsize: int + # a positive timedelta indicating how far back we should go + auto_reap_age: timedelta + + +# Task management types + + +class TaskHandlerState(StrEnum): + NotStarted = 'not started' + Initializing = 'initializing' + Running = 'running' + Success = 'success' + Error = 'error' + Killed = 'killed' + Failed = 'failed' + + def __lt__(self, other) -> bool: + """A logical ordering for TaskHandlerState: + + NotStarted < Initializing < 
Running < (Success, Error, Killed, Failed) + """ + if not isinstance(other, TaskHandlerState): + raise TypeError('cannot compare to non-TaskHandlerState') + order = (self.NotStarted, self.Initializing, self.Running) + smaller = set() + for value in order: + smaller.add(value) + if self == value: + return other not in smaller + + return False + + def __le__(self, other) -> bool: + # so that ((Success <= Error) is True) + return ((self < other) or + (self == other) or + (self.finished and other.finished)) + + def __gt__(self, other) -> bool: + if not isinstance(other, TaskHandlerState): + raise TypeError('cannot compare to non-TaskHandlerState') + order = (self.NotStarted, self.Initializing, self.Running) + smaller = set() + for value in order: + smaller.add(value) + if self == value: + return other in smaller + return other in smaller + + def __ge__(self, other) -> bool: + # so that ((Success <= Error) is True) + return ((self > other) or + (self == other) or + (self.finished and other.finished)) + + @property + def finished(self) -> bool: + return self in (self.Error, self.Success, self.Killed, self.Failed) + + +@dataclass +class TaskRow(JsonSchemaMixin): + task_id: uuid.UUID + request_id: Union[str, int] + request_source: str + method: str + state: TaskHandlerState + start: Optional[datetime] + end: Optional[datetime] + elapsed: Optional[float] + timeout: Optional[float] + tags: TaskTags + + +@dataclass +class PSResult(JsonSchemaMixin): + rows: List[TaskRow] + + +class KillResultStatus(StrEnum): + Missing = 'missing' + NotStarted = 'not_started' + Killed = 'killed' + Finished = 'finished' + + +@dataclass +class KillResult(JsonSchemaMixin): + status: KillResultStatus + + +# this is kind of carefuly structured: BlocksManifestTasks is implied by +# RequiresConfigReloadBefore and RequiresManifestReloadAfter +class RemoteMethodFlags(enum.Flag): + Empty = 0 + BlocksManifestTasks = 1 + RequiresConfigReloadBefore = 3 + RequiresManifestReloadAfter = 5 + + +# Polling types + + +@dataclass +class PollResult(JsonSchemaMixin): + tags: TaskTags = None + status: TaskHandlerState = TaskHandlerState.NotStarted + + +@dataclass +class PollRemoteEmptyCompleteResult(PollResult, RemoteEmptyResult): + status: TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Success, + TaskHandlerState.Failed), + default=TaskHandlerState.Success, + ) + + @classmethod + def from_result( + cls: Type['PollRemoteEmptyCompleteResult'], + status: TaskHandlerState, + base: RemoteEmptyResult, + tags: TaskTags, + ) -> 'PollRemoteEmptyCompleteResult': + return cls( + status=status, + logs=base.logs, + tags=tags, + ) + + +@dataclass +class PollKilledResult(PollResult): + logs: List[LogMessage] = field(default_factory=list) + status: TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Killed), + default=TaskHandlerState.Killed, + ) + + +@dataclass +class PollExecuteCompleteResult(PollResult, RemoteExecutionResult): + status: TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Success, + TaskHandlerState.Failed), + default=TaskHandlerState.Success, + ) + + @classmethod + def from_result( + cls: Type['PollExecuteCompleteResult'], + status: TaskHandlerState, + base: RemoteExecutionResult, + tags: TaskTags, + ) -> 'PollExecuteCompleteResult': + return cls( + status=status, + results=base.results, + generated_at=base.generated_at, + elapsed_time=base.elapsed_time, + logs=base.logs, + tags=tags, + ) + + +@dataclass +class PollCompileCompleteResult(PollResult, RemoteCompileResult): + status: 
TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Success, + TaskHandlerState.Failed), + default=TaskHandlerState.Success, + ) + + @classmethod + def from_result( + cls: Type['PollCompileCompleteResult'], + status: TaskHandlerState, + base: RemoteCompileResult, + tags: TaskTags, + ) -> 'PollCompileCompleteResult': + return cls( + status=status, + raw_sql=base.raw_sql, + compiled_sql=base.compiled_sql, + node=base.node, + timing=base.timing, + logs=base.logs, + tags=tags, + ) + + +@dataclass +class PollRunCompleteResult(PollResult, RemoteRunResult): + status: TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Success, + TaskHandlerState.Failed), + default=TaskHandlerState.Success, + ) + + @classmethod + def from_result( + cls: Type['PollRunCompleteResult'], + status: TaskHandlerState, + base: RemoteRunResult, + tags: TaskTags, + ) -> 'PollRunCompleteResult': + return cls( + status=status, + raw_sql=base.raw_sql, + compiled_sql=base.compiled_sql, + node=base.node, + timing=base.timing, + logs=base.logs, + table=base.table, + tags=tags, + ) + + +@dataclass +class PollCatalogCompleteResult(PollResult, RemoteCatalogResults): + status: TaskHandlerState = field( + metadata=restrict_to(TaskHandlerState.Success, + TaskHandlerState.Failed), + default=TaskHandlerState.Success, + ) + + @classmethod + def from_result( + cls: Type['PollCatalogCompleteResult'], + status: TaskHandlerState, + base: RemoteCatalogResults, + tags: TaskTags, + ) -> 'PollCatalogCompleteResult': + return cls( + status=status, + nodes=base.nodes, + generated_at=base.generated_at, + _compile_results=base._compile_results, + logs=base.logs, + tags=tags, + ) + + +@dataclass +class PollInProgressResult(PollResult): + logs: List[LogMessage] = field(default_factory=list) + + +# Manifest parsing types + +class ManifestStatus(StrEnum): + Init = 'init' + Compiling = 'compiling' + Ready = 'ready' + Error = 'error' + + +@dataclass +class LastParse(JsonSchemaMixin): + status: ManifestStatus + error: Optional[Dict[str, Any]] = None + logs: Optional[List[Dict[str, Any]]] = None + timestamp: datetime = field(default_factory=datetime.utcnow) + pid: int = field(default_factory=os.getpid) diff --git a/core/dbt/deps/base.py b/core/dbt/deps/base.py new file mode 100644 index 00000000000..15670ab04bd --- /dev/null +++ b/core/dbt/deps/base.py @@ -0,0 +1,112 @@ +import abc +import os +import tempfile +from contextlib import contextmanager +from typing import List, Optional, Generic, TypeVar + +from dbt.clients import system +from dbt.contracts.project import ProjectPackageMetadata +from dbt.logger import GLOBAL_LOGGER as logger + +DOWNLOADS_PATH = None + + +def get_downloads_path(): + return DOWNLOADS_PATH + + +@contextmanager +def downloads_directory(): + global DOWNLOADS_PATH + remove_downloads = False + # the user might have set an environment variable. Set it to that, and do + # not remove it when finished. 
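# A minimal usage sketch for the downloads_directory() context manager added
# in this hunk (illustrative; the file name is made up, every other name comes
# from the code above and below):
import os
from dbt.deps.base import downloads_directory

with downloads_directory() as downloads:
    tar_path = os.path.join(downloads, 'example-package.tar.gz')
# Precedence, as implemented by the checks that follow: an explicit
# DBT_DOWNLOADS_DIR is reused and never removed; otherwise a per-run temp
# directory is created and cleaned up after a successful run.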
+ if DOWNLOADS_PATH is None: + DOWNLOADS_PATH = os.getenv('DBT_DOWNLOADS_DIR') + remove_downloads = False + # if we are making a per-run temp directory, remove it at the end of + # successful runs + if DOWNLOADS_PATH is None: + DOWNLOADS_PATH = tempfile.mkdtemp(prefix='dbt-downloads-') + remove_downloads = True + + system.make_directory(DOWNLOADS_PATH) + logger.debug("Set downloads directory='{}'".format(DOWNLOADS_PATH)) + + yield DOWNLOADS_PATH + + if remove_downloads: + system.rmtree(DOWNLOADS_PATH) + DOWNLOADS_PATH = None + + +class BasePackage(metaclass=abc.ABCMeta): + @abc.abstractproperty + def name(self) -> str: + raise NotImplementedError + + def all_names(self) -> List[str]: + return [self.name] + + @abc.abstractmethod + def source_type(self) -> str: + raise NotImplementedError + + +class PinnedPackage(BasePackage): + def __init__(self) -> None: + self._cached_metadata: Optional[ProjectPackageMetadata] = None + + def __str__(self) -> str: + version = self.get_version() + if not version: + return self.name + + return '{}@{}'.format(self.name, version) + + @abc.abstractmethod + def get_version(self) -> Optional[str]: + raise NotImplementedError + + @abc.abstractmethod + def _fetch_metadata(self, project): + raise NotImplementedError + + @abc.abstractmethod + def install(self, project): + raise NotImplementedError + + @abc.abstractmethod + def nice_version_name(self): + raise NotImplementedError + + def fetch_metadata(self, project): + if not self._cached_metadata: + self._cached_metadata = self._fetch_metadata(project) + return self._cached_metadata + + def get_project_name(self, project): + metadata = self.fetch_metadata(project) + return metadata.name + + def get_installation_path(self, project): + dest_dirname = self.get_project_name(project) + return os.path.join(project.modules_path, dest_dirname) + + +SomePinned = TypeVar('SomePinned', bound=PinnedPackage) +SomeUnpinned = TypeVar('SomeUnpinned', bound='UnpinnedPackage') + + +class UnpinnedPackage(Generic[SomePinned], BasePackage): + @abc.abstractclassmethod + def from_contract(cls, contract): + raise NotImplementedError + + @abc.abstractmethod + def incorporate(self: SomeUnpinned, other: SomeUnpinned) -> SomeUnpinned: + raise NotImplementedError + + @abc.abstractmethod + def resolved(self) -> SomePinned: + raise NotImplementedError diff --git a/core/dbt/deps/git.py b/core/dbt/deps/git.py new file mode 100644 index 00000000000..52a5fc000ff --- /dev/null +++ b/core/dbt/deps/git.py @@ -0,0 +1,145 @@ +import os +import hashlib +from typing import List + +from dbt.clients import git, system +from dbt.config import Project +from dbt.contracts.project import ( + ProjectPackageMetadata, + GitPackage, +) +from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path +from dbt.exceptions import ( + ExecutableError, warn_or_error, raise_dependency_error +) +from dbt.logger import GLOBAL_LOGGER as logger +from dbt.ui import printer + +PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa + + +def md5sum(s: str): + return hashlib.md5(s.encode('latin-1')).hexdigest() + + +class GitPackageMixin: + def __init__(self, git: str) -> None: + super().__init__() + self.git = git + + @property + def name(self): + return self.git + + def source_type(self) -> str: + return 'git' + + +class GitPinnedPackage(GitPackageMixin, PinnedPackage): + def __init__( + self, git: str, revision: str, warn_unpinned: bool = True + ) -> None: + super().__init__(git) + self.revision = revision + 
self.warn_unpinned = warn_unpinned + self._checkout_name = md5sum(self.git) + + def get_version(self): + return self.revision + + def nice_version_name(self): + return 'revision {}'.format(self.revision) + + def _checkout(self): + """Performs a shallow clone of the repository into the downloads + directory. This function can be called repeatedly. If the project has + already been checked out at this version, it will be a no-op. Returns + the path to the checked out directory.""" + try: + dir_ = git.clone_and_checkout( + self.git, get_downloads_path(), branch=self.revision, + dirname=self._checkout_name + ) + except ExecutableError as exc: + if exc.cmd and exc.cmd[0] == 'git': + logger.error( + 'Make sure git is installed on your machine. More ' + 'information: ' + 'https://docs.getdbt.com/docs/package-management' + ) + raise + return os.path.join(get_downloads_path(), dir_) + + def _fetch_metadata(self, project) -> ProjectPackageMetadata: + path = self._checkout() + if self.revision == 'master' and self.warn_unpinned: + warn_or_error( + 'The git package "{}" is not pinned.\n\tThis can introduce ' + 'breaking changes into your project without warning!\n\nSee {}' + .format(self.git, PIN_PACKAGE_URL), + log_fmt=printer.yellow('WARNING: {}') + ) + loaded = Project.from_project_root(path, {}) + return ProjectPackageMetadata.from_project(loaded) + + def install(self, project): + dest_path = self.get_installation_path(project) + if os.path.exists(dest_path): + if system.path_is_symlink(dest_path): + system.remove_file(dest_path) + else: + system.rmdir(dest_path) + + system.move(self._checkout(), dest_path) + + +class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]): + def __init__( + self, git: str, revisions: List[str], warn_unpinned: bool = True + ) -> None: + super().__init__(git) + self.revisions = revisions + self.warn_unpinned = warn_unpinned + + @classmethod + def from_contract( + cls, contract: GitPackage + ) -> 'GitUnpinnedPackage': + revisions = [contract.revision] if contract.revision else [] + + # we want to map None -> True + warn_unpinned = contract.warn_unpinned is not False + return cls(git=contract.git, revisions=revisions, + warn_unpinned=warn_unpinned) + + def all_names(self) -> List[str]: + if self.git.endswith('.git'): + other = self.git[:-4] + else: + other = self.git + '.git' + return [self.git, other] + + def incorporate( + self, other: 'GitUnpinnedPackage' + ) -> 'GitUnpinnedPackage': + warn_unpinned = self.warn_unpinned and other.warn_unpinned + + return GitUnpinnedPackage( + git=self.git, + revisions=self.revisions + other.revisions, + warn_unpinned=warn_unpinned, + ) + + def resolved(self) -> GitPinnedPackage: + requested = set(self.revisions) + if len(requested) == 0: + requested = {'master'} + elif len(requested) > 1: + raise_dependency_error( + 'git dependencies should contain exactly one version. 
' + '{} contains: {}'.format(self.git, requested)) + + return GitPinnedPackage( + git=self.git, revision=requested.pop(), + warn_unpinned=self.warn_unpinned + ) diff --git a/core/dbt/deps/local.py b/core/dbt/deps/local.py new file mode 100644 index 00000000000..175c03e434b --- /dev/null +++ b/core/dbt/deps/local.py @@ -0,0 +1,82 @@ +import shutil + +from dbt.clients import system +from dbt.deps.base import PinnedPackage, UnpinnedPackage +from dbt.contracts.project import ( + ProjectPackageMetadata, + LocalPackage, +) +from dbt.logger import GLOBAL_LOGGER as logger + + +class LocalPackageMixin: + def __init__(self, local: str) -> None: + super().__init__() + self.local = local + + @property + def name(self): + return self.local + + def source_type(self): + return 'local' + + +class LocalPinnedPackage(LocalPackageMixin, PinnedPackage): + def __init__(self, local: str) -> None: + super().__init__(local) + + def get_version(self): + return None + + def nice_version_name(self): + return ''.format(self.local) + + def resolve_path(self, project): + return system.resolve_path_from_base( + self.local, + project.project_root, + ) + + def _fetch_metadata(self, project): + loaded = project.from_project_root(self.resolve_path(project), {}) + return ProjectPackageMetadata.from_project(loaded) + + def install(self, project): + src_path = self.resolve_path(project) + dest_path = self.get_installation_path(project) + + can_create_symlink = system.supports_symlinks() + + if system.path_exists(dest_path): + if not system.path_is_symlink(dest_path): + system.rmdir(dest_path) + else: + system.remove_file(dest_path) + + if can_create_symlink: + logger.debug(' Creating symlink to local dependency.') + system.make_symlink(src_path, dest_path) + + else: + logger.debug(' Symlinks are not available on this ' + 'OS, copying dependency.') + shutil.copytree(src_path, dest_path) + + +class LocalUnpinnedPackage( + LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage] +): + @classmethod + def from_contract( + cls, contract: LocalPackage + ) -> 'LocalUnpinnedPackage': + return cls(local=contract.local) + + def incorporate( + self, other: 'LocalUnpinnedPackage' + ) -> 'LocalUnpinnedPackage': + return LocalUnpinnedPackage(local=self.local) + + def resolved(self) -> LocalPinnedPackage: + return LocalPinnedPackage(local=self.local) diff --git a/core/dbt/deps/registry.py b/core/dbt/deps/registry.py new file mode 100644 index 00000000000..e9200f90e56 --- /dev/null +++ b/core/dbt/deps/registry.py @@ -0,0 +1,124 @@ +import os +from typing import List + +from dbt import semver +from dbt.clients import registry, system +from dbt.contracts.project import ( + RegistryPackageMetadata, + RegistryPackage, +) +from dbt.deps.base import PinnedPackage, UnpinnedPackage, get_downloads_path +from dbt.exceptions import ( + package_version_not_found, + VersionsNotCompatibleException, + DependencyException, + package_not_found, +) + + +class RegistryPackageMixin: + def __init__(self, package: str) -> None: + super().__init__() + self.package = package + + @property + def name(self): + return self.package + + def source_type(self) -> str: + return 'hub' + + +class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage): + def __init__(self, package: str, version: str) -> None: + super().__init__(package) + self.version = version + + @property + def name(self): + return self.package + + def source_type(self): + return 'hub' + + def get_version(self): + return self.version + + def nice_version_name(self): + return 'version 
{}'.format(self.version) + + def _fetch_metadata(self, project) -> RegistryPackageMetadata: + dct = registry.package_version(self.package, self.version) + return RegistryPackageMetadata.from_dict(dct) + + def install(self, project): + metadata = self.fetch_metadata(project) + + tar_name = '{}.{}.tar.gz'.format(self.package, self.version) + tar_path = os.path.realpath( + os.path.join(get_downloads_path(), tar_name) + ) + system.make_directory(os.path.dirname(tar_path)) + + download_url = metadata.downloads.tarball + system.download(download_url, tar_path) + deps_path = project.modules_path + package_name = self.get_project_name(project) + system.untar_package(tar_path, deps_path, package_name) + + +class RegistryUnpinnedPackage( + RegistryPackageMixin, UnpinnedPackage[RegistryPinnedPackage] +): + def __init__( + self, package: str, versions: List[semver.VersionSpecifier] + ) -> None: + super().__init__(package) + self.versions = versions + + def _check_in_index(self): + index = registry.index_cached() + if self.package not in index: + package_not_found(self.package) + + @classmethod + def from_contract( + cls, contract: RegistryPackage + ) -> 'RegistryUnpinnedPackage': + raw_version = contract.version + if isinstance(raw_version, str): + raw_version = [raw_version] + + versions = [ + semver.VersionSpecifier.from_version_string(v) + for v in raw_version + ] + return cls(package=contract.package, versions=versions) + + def incorporate( + self, other: 'RegistryUnpinnedPackage' + ) -> 'RegistryUnpinnedPackage': + return RegistryUnpinnedPackage( + package=self.package, + versions=self.versions + other.versions, + ) + + def resolved(self) -> RegistryPinnedPackage: + self._check_in_index() + try: + range_ = semver.reduce_versions(*self.versions) + except VersionsNotCompatibleException as e: + new_msg = ('Version error for package {}: {}' + .format(self.name, e)) + raise DependencyException(new_msg) from e + + available = registry.get_available_versions(self.package) + + # for now, pick a version and then recurse. later on, + # we'll probably want to traverse multiple options + # so we can match packages. not going to make a difference + # right now. 
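# An illustrative walk-through of the resolution performed below, using only
# the semver helpers already called in this module (the version strings are
# made up; exact pick order is assumed to prefer the newest match):
from dbt import semver

specs = [
    semver.VersionSpecifier.from_version_string(v)
    for v in ['>=0.1.0', '<0.3.0']
]
range_ = semver.reduce_versions(*specs)  # combined range; raises if the specs conflict
available = ['0.1.0', '0.2.2']           # stand-in for registry.get_available_versions()
target = semver.resolve_to_specific_version(range_, available)
# target is a concrete version inside the range (presumably '0.2.2' here), or
# None when nothing matches -- which the code below reports via
# package_version_not_found. The resolution call follows: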
+ target = semver.resolve_to_specific_version(range_, available) + if not target: + package_version_not_found(self.package, range_, available) + return RegistryPinnedPackage(package=self.package, version=target) diff --git a/core/dbt/deps/resolver.py b/core/dbt/deps/resolver.py new file mode 100644 index 00000000000..d966ee6f703 --- /dev/null +++ b/core/dbt/deps/resolver.py @@ -0,0 +1,128 @@ +from dataclasses import dataclass, field +from typing import Dict, List, NoReturn, Union, Type, Iterator + +from dbt.exceptions import raise_dependency_error, InternalException + +from dbt.deps.base import BasePackage, PinnedPackage, UnpinnedPackage +from dbt.deps.local import LocalUnpinnedPackage +from dbt.deps.git import GitUnpinnedPackage +from dbt.deps.registry import RegistryUnpinnedPackage + +from dbt.contracts.project import ( + LocalPackage, + GitPackage, + RegistryPackage, +) + +PackageContract = Union[LocalPackage, GitPackage, RegistryPackage] + + +@dataclass +class PackageListing: + packages: Dict[str, UnpinnedPackage] = field(default_factory=dict) + + def __len__(self): + return len(self.packages) + + def __bool__(self): + return bool(self.packages) + + def _pick_key(self, key: BasePackage) -> str: + for name in key.all_names(): + if name in self.packages: + return name + return key.name + + def __contains__(self, key: BasePackage): + for name in key.all_names(): + if name in self.packages: + return True + + def __getitem__(self, key: BasePackage): + key_str: str = self._pick_key(key) + return self.packages[key_str] + + def __setitem__(self, key: BasePackage, value): + key_str: str = self._pick_key(key) + self.packages[key_str] = value + + def _mismatched_types( + self, old: UnpinnedPackage, new: UnpinnedPackage + ) -> NoReturn: + raise_dependency_error( + f'Cannot incorporate {new} ({new.__class__.__name__}) in {old} ' + f'({old.__class__.__name__}): mismatched types' + ) + + def incorporate(self, package: UnpinnedPackage): + key: str = self._pick_key(package) + if key in self.packages: + existing: UnpinnedPackage = self.packages[key] + if not isinstance(existing, type(package)): + self._mismatched_types(existing, package) + self.packages[key] = existing.incorporate(package) + else: + self.packages[key] = package + + def update_from(self, src: List[PackageContract]) -> None: + pkg: UnpinnedPackage + for contract in src: + if isinstance(contract, LocalPackage): + pkg = LocalUnpinnedPackage.from_contract(contract) + elif isinstance(contract, GitPackage): + pkg = GitUnpinnedPackage.from_contract(contract) + elif isinstance(contract, RegistryPackage): + pkg = RegistryUnpinnedPackage.from_contract(contract) + else: + raise InternalException( + 'Invalid package type {}'.format(type(contract)) + ) + self.incorporate(pkg) + + @classmethod + def from_contracts( + cls: Type['PackageListing'], src: List[PackageContract] + ) -> 'PackageListing': + self = cls({}) + self.update_from(src) + return self + + def resolved(self) -> List[PinnedPackage]: + return [p.resolved() for p in self.packages.values()] + + def __iter__(self) -> Iterator[UnpinnedPackage]: + return iter(self.packages.values()) + + +def _check_for_duplicate_project_names( + final_deps: List[PinnedPackage], config +): + seen = set() + for package in final_deps: + project_name = package.get_project_name(config) + if project_name in seen: + raise_dependency_error( + 'Found duplicate project {}. This occurs when a dependency' + ' has the same project name as some other dependency.' 
+ .format(project_name)) + seen.add(project_name) + + +def resolve_packages( + packages: List[PackageContract], config +) -> List[PinnedPackage]: + pending = PackageListing.from_contracts(packages) + final = PackageListing() + + while pending: + next_pending = PackageListing() + # resolve the dependency in question + for package in pending: + final.incorporate(package) + target = final[package].resolved().fetch_metadata(config) + next_pending.update_from(target.packages) + pending = next_pending + + resolved = final.resolved() + _check_for_duplicate_project_names(resolved, config) + return resolved diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index bc048ae00e8..92b9d9f3fd0 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -169,6 +169,11 @@ class RPCCompiling(RuntimeException): ' compile status' ) + def __init__(self, msg=None, node=None): + if msg is None: + msg = 'compile in progress' + super().__init__(msg, node) + class RPCLoadException(RuntimeException): CODE = 10011 diff --git a/core/dbt/main.py b/core/dbt/main.py index fcd7d504fd0..7f79afe4c3f 100644 --- a/core/dbt/main.py +++ b/core/dbt/main.py @@ -364,7 +364,7 @@ def _build_deps_subparser(subparsers, base_subparser): Pull the most recent version of the dependencies listed in packages.yml ''' ) - sub.set_defaults(cls=deps_task.DepsTask, which='deps', rpc_method=None) + sub.set_defaults(cls=deps_task.DepsTask, which='deps', rpc_method='deps') return sub diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index dd05dffa1d6..ec752b809ca 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -303,7 +303,7 @@ def parse_test( '\n\t{}\n\t@: {}' .format(block.path.original_file_path, exc.msg, context) ) - raise CompilationException(msg) + raise CompilationException(msg) from exc def _calculate_freshness( self, diff --git a/core/dbt/rpc/logger.py b/core/dbt/rpc/logger.py index 508ad27772b..cfe006cfad1 100644 --- a/core/dbt/rpc/logger.py +++ b/core/dbt/rpc/logger.py @@ -7,22 +7,15 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from queue import Empty -from typing import Optional, Any, Union +from typing import Optional, Any from dbt.contracts.rpc import ( - RemoteCompileResult, RemoteExecutionResult, RemoteCatalogResults + RemoteResult, ) from dbt.exceptions import InternalException from dbt.utils import restrict_to -RPCResult = Union[ - RemoteCompileResult, - RemoteExecutionResult, - RemoteCatalogResults, -] - - class QueueMessageType(StrEnum): Error = 'error' Result = 'result' @@ -73,10 +66,10 @@ class QueueResultMessage(QueueMessage): message_type: QueueMessageType = field( metadata=restrict_to(QueueMessageType.Result) ) - result: RPCResult + result: RemoteResult @classmethod - def from_result(cls, result: RPCResult): + def from_result(cls, result: RemoteResult): return cls( message_type=QueueMessageType.Result, result=result, @@ -101,7 +94,7 @@ def emit(self, record: logbook.LogRecord): def emit_error(self, error: JSONRPCError): self.queue.put_nowait(QueueErrorMessage.from_error(error)) - def emit_result(self, result: RPCResult): + def emit_result(self, result: RemoteResult): self.queue.put_nowait(QueueResultMessage.from_result(result)) diff --git a/core/dbt/rpc/method.py b/core/dbt/rpc/method.py index 29521c7246f..c9e0d77d07c 100644 --- a/core/dbt/rpc/method.py +++ b/core/dbt/rpc/method.py @@ -1,29 +1,26 @@ import inspect from abc import abstractmethod from typing import List, Optional, Type, TypeVar, Generic +from typing import 
Any # noqa -from dbt.contracts.rpc import RPCParameters +from dbt.contracts.rpc import RPCParameters, RemoteResult, RemoteMethodFlags from dbt.exceptions import NotImplementedException, InternalException -from dbt.rpc.logger import RPCResult - Parameters = TypeVar('Parameters', bound=RPCParameters) -Result = TypeVar('Result', bound=RPCResult) +Result = TypeVar('Result', bound=RemoteResult) -# If you call recursive_subclasses on a subclass of RemoteMethod, it should +# If you call recursive_subclasses on a subclass of BaseRemoteMethod, it should # only return subtypes of the given subclass. T = TypeVar('T', bound='RemoteMethod') class RemoteMethod(Generic[Parameters, Result]): METHOD_NAME: Optional[str] = None - is_async = False - def __init__(self, args, config, manifest): + def __init__(self, args, config): self.args = args self.config = config - self.manifest = manifest @classmethod def get_parameters(cls) -> Type[Parameters]: @@ -48,9 +45,13 @@ def get_parameters(cls) -> Type[Parameters]: ) return params_type + def get_flags(self) -> RemoteMethodFlags: + return RemoteMethodFlags.Empty + @classmethod def recursive_subclasses( - cls: Type[T], named_only: bool = True + cls: Type[T], + named_only: bool = True, ) -> List[Type[T]]: classes = [] current = [cls] @@ -65,12 +66,51 @@ def recursive_subclasses( @abstractmethod def set_args(self, params: Parameters): - raise NotImplementedException( - 'set_args not implemented' - ) + """set_args executes in the parent process for an RPC call""" + raise NotImplementedException('set_args not implemented') @abstractmethod def handle_request(self) -> Result: - raise NotImplementedException( - 'handle_request not implemented' - ) + """handle_request executes inside the child process for an RPC call""" + raise NotImplementedException('handle_request not implemented') + + def cleanup(self, result: Optional[Result]): + """cleanup is an optional method that executes inside the parent + process for an RPC call. + + This will always be executed if set_args was. + + It's optional, and by default it does nothing. 
+ """ + + def set_config(self, config): + self.config = config + + +class RemoteManifestMethod(RemoteMethod[Parameters, Result]): + def __init__(self, args, config, manifest): + super().__init__(args, config) + self.manifest = manifest + + +class TaskList(List[Type[RemoteMethod]]): + def __init__( + self, + tasks: Optional[List[Type[RemoteMethod]]] = None + ): + task_list: List[Type[RemoteMethod]] + if tasks is None: + task_list = RemoteMethod.recursive_subclasses(named_only=True) + else: + task_list = tasks + return super().__init__(task_list) + + def manifest(self) -> List[Type[RemoteManifestMethod]]: + return [ + t for t in self if issubclass(t, RemoteManifestMethod) + ] + + def non_manifest(self) -> List[Type[RemoteMethod]]: + return [ + t for t in self if not issubclass(t, RemoteManifestMethod) + ] diff --git a/core/dbt/rpc/task_handler.py b/core/dbt/rpc/task_handler.py index 5a80c3ea679..a45b963d069 100644 --- a/core/dbt/rpc/task_handler.py +++ b/core/dbt/rpc/task_handler.py @@ -1,22 +1,26 @@ -import multiprocessing import signal import sys import threading import uuid +from contextlib import contextmanager from datetime import datetime -from typing import Any, Dict, Union, Optional, List, Type +from typing import ( + Any, Dict, Union, Optional, List, Type, Callable, Iterator +) +from typing_extensions import Protocol from hologram import JsonSchemaMixin, ValidationError -from hologram.helpers import StrEnum import dbt.exceptions import dbt.flags from dbt.adapters.factory import ( - cleanup_connections, load_plugin, register_adapter + cleanup_connections, load_plugin, register_adapter, +) +from dbt.contracts.rpc import ( + RPCParameters, RemoteResult, TaskHandlerState, RemoteMethodFlags, ) -from dbt.contracts.rpc import RPCParameters from dbt.logger import ( - GLOBAL_LOGGER as logger, list_handler, LogMessage, OutputHandler + GLOBAL_LOGGER as logger, list_handler, LogMessage, OutputHandler, ) from dbt.rpc.error import ( dbt_error, @@ -25,7 +29,6 @@ timeout_error, ) from dbt.rpc.logger import ( - RPCResult, QueueSubscriber, QueueLogHandler, QueueErrorMessage, @@ -42,120 +45,147 @@ SINGLE_THREADED_HANDLER = env_set_truthy('DBT_SINGLE_THREADED_HANDLER') -class TaskHandlerState(StrEnum): - NotStarted = 'not started' - Initializing = 'initializing' - Running = 'running' - Success = 'success' - Error = 'error' +def sigterm_handler(signum, frame): + raise dbt.exceptions.RPCKilledException(signum) - def __lt__(self, other) -> bool: - """A logical ordering for TaskHandlerState: - NotStarted < Initializing < Running < (Success, Error) +class BootstrapProcess(dbt.flags.MP_CONTEXT.Process): + def __init__( + self, + task: RemoteMethod, + queue, # typing: Queue[Tuple[QueueMessageType, Any]] + ) -> None: + self.task = task + self.queue = queue + super().__init__() + + def _spawn_setup(self): """ - if not isinstance(other, TaskHandlerState): - raise TypeError('cannot compare to non-TaskHandlerState') - order = (self.NotStarted, self.Initializing, self.Running) - smaller = set() - for value in order: - smaller.add(value) - if self == value: - return other not in smaller + Because we're using spawn, we have to do a some things that dbt does + dynamically at process load. - return False + These things are inherited automatically in fork mode, where fork() + keeps everything in memory. 
+ """ + # reset flags + dbt.flags.set_from_args(self.task.args) + # reload the active plugin + load_plugin(self.task.config.credentials.type) + # register it + register_adapter(self.task.config) + + # reset tracking, etc + self.task.config.config.set_values(self.task.args.profiles_dir) + + def task_exec(self) -> None: + """task_exec runs first inside the child process""" + signal.signal(signal.SIGTERM, sigterm_handler) + # the first thing we do in a new process: push logging back over our + # queue + handler = QueueLogHandler(self.queue) + with handler.applicationbound(): + self._spawn_setup() + rpc_exception = None + result = None + try: + result = self.task.handle_request() + except RPCException as exc: + rpc_exception = exc + except dbt.exceptions.RPCKilledException as exc: + # do NOT log anything here, you risk triggering a deadlock on + # the queue handler we inserted above + rpc_exception = dbt_error(exc) + except dbt.exceptions.Exception as exc: + logger.debug('dbt runtime exception', exc_info=True) + rpc_exception = dbt_error(exc) + except Exception as exc: + with OutputHandler(sys.stderr).applicationbound(): + logger.error('uncaught python exception', exc_info=True) + rpc_exception = server_error(exc) + + # put whatever result we got onto the queue as well. + if rpc_exception is not None: + handler.emit_error(rpc_exception.error) + elif result is not None: + handler.emit_result(result) + else: + error = dbt_error(dbt.exceptions.InternalException( + 'after request handling, neither result nor error is None!' + )) + handler.emit_error(error.error) - def __le__(self, other) -> bool: - # so that ((Success <= Error) is True) - return ((self < other) or - (self == other) or - (self.finished and other.finished)) - - def __gt__(self, other) -> bool: - if not isinstance(other, TaskHandlerState): - raise TypeError('cannot compare to non-TaskHandlerState') - order = (self.NotStarted, self.Initializing, self.Running) - smaller = set() - for value in order: - smaller.add(value) - if self == value: - return other in smaller - return other in smaller - - def __ge__(self, other) -> bool: - # so that ((Success <= Error) is True) - return ((self > other) or - (self == other) or - (self.finished and other.finished)) + def run(self): + self.task_exec() - @property - def finished(self) -> bool: - return self == self.Error or self == self.Success +class TaskManagerProtocol(Protocol): + config: Any -def sigterm_handler(signum, frame): - raise dbt.exceptions.RPCKilledException(signum) + def set_parsing(self): + pass + def set_compile_exception( + self, exc: Exception, logs: List[Dict[str, Any]] + ): + pass -def _spawn_setup(config, args): - """ - Because we're using spawn, we have to do a some things that dbt does - dynamically at process load. + def set_ready(self, logs: List[Dict[str, Any]]): + pass + + def add_request(self, request: 'RequestTaskHandler') -> Dict[str, Any]: + pass + + def parse_manifest(self): + pass + + def reload_config(self): + pass - These things are inherited automatically in fork mode, where fork() keeps - everything in memory. + +@contextmanager +def set_parse_state_with( + manager: TaskManagerProtocol, + logs: Callable[[], List[LogMessage]], +) -> Iterator[None]: + """Given a task manager and either a list of logs or a callable that + returns said list, set appropriate state on the manager upon exiting. 
""" - # reset flags - dbt.flags.set_from_args(args) - # reload the active plugin - load_plugin(config.credentials.type) - # register it - register_adapter(config) - - # reset tracking, etc - config.config.set_values(args.profiles_dir) - - -def _task_bootstrap( - task: RemoteMethod, - queue, # typing: Queue[Tuple[QueueMessageType, Any]] - params: JsonSchemaMixin, -) -> None: - """_task_bootstrap runs first inside the child process""" - signal.signal(signal.SIGTERM, sigterm_handler) - # the first thing we do in a new process: push logging back over our queue - handler = QueueLogHandler(queue) - with handler.applicationbound(): - _spawn_setup(task.config, task.args) - rpc_exception = None - result = None - try: - task.set_args(params=params) - result = task.handle_request() - except RPCException as exc: - rpc_exception = exc - except dbt.exceptions.RPCKilledException as exc: - # do NOT log anything here, you risk triggering a deadlock on the - # queue handler we inserted above - rpc_exception = dbt_error(exc) - except dbt.exceptions.Exception as exc: - logger.debug('dbt runtime exception', exc_info=True) - rpc_exception = dbt_error(exc) - except Exception as exc: - with OutputHandler(sys.stderr).applicationbound(): - logger.error('uncaught python exception', exc_info=True) - rpc_exception = server_error(exc) - - # put whatever result we got onto the queue as well. - if rpc_exception is not None: - handler.emit_error(rpc_exception.error) - elif result is not None: - handler.emit_result(result) - else: - error = dbt_error(dbt.exceptions.InternalException( - 'after request handling, neither result nor error is None!' - )) - handler.emit_error(error.error) + try: + yield + except Exception as exc: + log_dicts = [r.to_dict() for r in logs()] + manager.set_compile_exception(exc, logs=log_dicts) + # re-raise to ensure any exception handlers above trigger. We might be + # in an API call that set the parse state, in which case we don't want + # to swallow the exception - it also should report its failure to the + # task manager. + raise + else: + log_dicts = [r.to_dict() for r in logs()] + manager.set_ready(log_dicts) + + +@contextmanager +def _noop_context() -> Iterator[None]: + yield + + +@contextmanager +def get_results_context( + flags: RemoteMethodFlags, + manager: TaskManagerProtocol, + logs: Callable[[], List[LogMessage]] +) -> Iterator[None]: + + if RemoteMethodFlags.BlocksManifestTasks in flags: + manifest_blocking = set_parse_state_with(manager, logs) + else: + manifest_blocking = _noop_context() + + with manifest_blocking: + yield + if RemoteMethodFlags.RequiresManifestReloadAfter: + manager.parse_manifest() class StateHandler: @@ -169,50 +199,80 @@ def __enter__(self) -> None: def set_end(self): self.handler.ended = datetime.utcnow() - def handle_success(self): - self.handler.state = TaskHandlerState.Success + def handle_completed(self): + # killed handlers don't get a result. 
+ if self.handler.state != TaskHandlerState.Killed: + if self.handler.result is None: + # there wasn't an error before, but there sure is one now + self.handler.error = dbt_error( + dbt.exceptions.InternalException( + 'got an invalid result=None, but state was {}' + .format(self.handler.state) + ) + ) + elif self.handler.task.interpret_results(self.handler.result): + self.handler.state = TaskHandlerState.Success + else: + self.handler.state = TaskHandlerState.Failed self.set_end() def handle_error(self, exc_type, exc_value, exc_tb) -> bool: if isinstance(exc_value, RPCException): self.handler.error = exc_value - self.handler.state = TaskHandlerState.Error elif isinstance(exc_value, dbt.exceptions.Exception): self.handler.error = dbt_error(exc_value) - self.handler.state = TaskHandlerState.Error else: # we should only get here if we got a BaseException that is not # an Exception (we caught those in _wait_for_results), or a bug # in get_result's call stack. Either way, we should set an # error so we can figure out what happened on thread death self.handler.error = server_error(exc_value) + if self.handler.state != TaskHandlerState.Killed: self.handler.state = TaskHandlerState.Error self.set_end() return False - def __exit__(self, exc_type, exc_value, exc_tb) -> bool: - if exc_type is not None: - return self.handle_error(exc_type, exc_value, exc_tb) - - self.handle_success() - return False + def task_teardown(self): + self.handler.task.cleanup(self.handler.result) + def __exit__(self, exc_type, exc_value, exc_tb) -> bool: + try: + if exc_type is not None: + self.handle_error(exc_type, exc_value, exc_tb) + else: + self.handle_completed() + return False + finally: + # we really really promise to run your teardown + self.task_teardown() + + +class SetArgsStateHandler(StateHandler): + """A state handler that does not touch state on success and does not + execute the teardown + """ + def handle_completed(self): + pass -class ErrorOnlyStateHandler(StateHandler): - """A state handler that does not touch state on success.""" - def handle_success(self): + def handle_teardown(self): pass class RequestTaskHandler(threading.Thread): """Handler for the single task triggered by a given jsonrpc request.""" - def __init__(self, manager, task, http_request, json_rpc_request): - self.manager = manager - self.task = task + def __init__( + self, + manager: TaskManagerProtocol, + task: RemoteMethod, + http_request, + json_rpc_request, + ) -> None: + self.manager: TaskManagerProtocol = manager + self.task: RemoteMethod = task self.http_request = http_request self.json_rpc_request = json_rpc_request self.subscriber: Optional[QueueSubscriber] = None - self.process: Optional[multiprocessing.Process] = None + self.process: Optional[BootstrapProcess] = None self.thread: Optional[threading.Thread] = None self.started: Optional[datetime] = None self.ended: Optional[datetime] = None @@ -244,11 +304,16 @@ def request_id(self) -> Union[str, int]: @property def method(self) -> str: + if self.task.METHOD_NAME is None: # mypy appeasement + raise dbt.exceptions.InternalException( + f'In the request handler, got a task({self.task}) with no ' + 'METHOD_NAME' + ) return self.task.METHOD_NAME @property def _single_threaded(self): - return self.task.args.single_threaded or SINGLE_THREADED_HANDLER + return bool(self.task.args.single_threaded or SINGLE_THREADED_HANDLER) @property def timeout(self) -> Optional[float]: @@ -264,7 +329,7 @@ def tags(self) -> Optional[Dict[str, Any]]: return None return self.task_params.task_tags - def 
_wait_for_results(self) -> RPCResult: + def _wait_for_results(self) -> RemoteResult: """Wait for results off the queue. If there is an exception raised, raise an appropriate RPC exception. @@ -302,43 +367,57 @@ def _wait_for_results(self) -> RPCResult: 'Invalid message type {} (result={})'.format(msg) ) - def get_result(self) -> RPCResult: + def get_result(self) -> RemoteResult: if self.process is None: raise dbt.exceptions.InternalException( 'get_result() called before handle()' ) - try: - with list_handler(self.logs): - try: - result = self._wait_for_results() - finally: - if not self._single_threaded: - self.process.join() - except RPCException as exc: - # RPC Exceptions come already preserialized for the jsonrpc - # framework - exc.logs = [l.to_dict() for l in self.logs] - exc.tags = self.tags - raise - - # results get real logs - result.logs = self.logs[:] - return result + flags = self.task.get_flags() + + # If we blocked the manifest tasks, we need to un-set them on exit. + # threaded mode handles this on its own. + with get_results_context(flags, self.manager, lambda: self.logs): + try: + with list_handler(self.logs): + try: + result = self._wait_for_results() + finally: + if not self._single_threaded: + self.process.join() + except RPCException as exc: + # RPC Exceptions come already preserialized for the jsonrpc + # framework + exc.logs = [l.to_dict() for l in self.logs] + exc.tags = self.tags + raise + + # results get real logs + result.logs = self.logs[:] + return result def run(self): try: with StateHandler(self): self.result = self.get_result() - except RPCException: - pass # rpc exceptions are fine, the managing thread will handle it - - def handle_singlethreaded(self, kwargs): + except (dbt.exceptions.Exception, RPCException): + # we probably got an error after the RPC call ran (and it was + # probably deps...). By now anyone who wanted to see it has seen it + # so we can suppress it to avoid stderr stack traces + pass + + def handle_singlethreaded( + self, kwargs: Dict[str, Any], flags: RemoteMethodFlags + ): # in single-threaded mode, we're going to remain synchronous, so call # `run`, not `start`, and return an actual result. # note this shouldn't call self.run() as that has different semantics # (we want errors to raise) - self.process.run() + if self.process is None: # mypy appeasement + raise dbt.exceptions.InternalException( + 'Cannot run a None process' + ) + self.process.task_exec() with StateHandler(self): self.result = self.get_result() return self.result @@ -373,23 +452,36 @@ def handle(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: self.started = datetime.utcnow() self.state = TaskHandlerState.Initializing self.task_kwargs = kwargs - with ErrorOnlyStateHandler(self): + + with SetArgsStateHandler(self): # this will raise a TypeError if you provided bad arguments. self.task_params = self._collect_parameters() - if self.task_params is None: - raise dbt.exceptions.InternalException( - 'Task params set to None!' - ) + self.task.set_args(self.task_params) + # now that we have called set_args, we can figure out our flags + flags: RemoteMethodFlags = self.task.get_flags() + if RemoteMethodFlags.RequiresConfigReloadBefore in flags: + # tell the manager to reload the config. + self.manager.reload_config() + # set our task config to the version on our manager now. RPCCLi + # tasks use this to set their `real_task`. 
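# The flags collected after set_args() drive the remaining setup (descriptive
# summary of the surrounding code):
#   RequiresConfigReloadBefore -> the manager reloads its config first
#   BlocksManifestTasks        -> the manager must enter the parsing state,
#                                 otherwise an RPCCompiling error is raised
# Regardless of flags, the task is then pointed at the manager's current
# config via the set_config() call that follows: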
+ self.task.set_config(self.manager.config) + if self.task_params is None: # mypy appeasement + raise dbt.exceptions.InternalException( + 'Task params set to None!' + ) + self.subscriber = QueueSubscriber(dbt.flags.MP_CONTEXT.Queue()) - self.process = dbt.flags.MP_CONTEXT.Process( - target=_task_bootstrap, - args=(self.task, self.subscriber.queue, self.task_params) - ) + self.process = BootstrapProcess(self.task, self.subscriber.queue) + + if RemoteMethodFlags.BlocksManifestTasks in flags: + # got a request to do some compiling, but we already are! + if not self.manager.set_parsing(): + raise dbt_error(dbt.exceptions.RPCCompiling()) if self._single_threaded: # all requests are synchronous in single-threaded mode. No need to # create a process... - return self.handle_singlethreaded(kwargs) + return self.handle_singlethreaded(kwargs, flags) self.start() return {'request_token': str(self.task_id)} diff --git a/core/dbt/rpc/task_manager.py b/core/dbt/rpc/task_manager.py index 8bd518997b5..29f1da8a170 100644 --- a/core/dbt/rpc/task_manager.py +++ b/core/dbt/rpc/task_manager.py @@ -1,8 +1,9 @@ import operator import os import signal +import threading import uuid -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timedelta from functools import wraps from typing import ( @@ -10,26 +11,50 @@ ) from hologram import JsonSchemaMixin, ValidationError -from hologram.helpers import StrEnum import dbt.exceptions import dbt.flags from dbt.contracts.graph.manifest import Manifest from dbt.contracts.rpc import ( - RemoteCompileResult, - RemoteRunResult, + TaskTags, + LastParse, + ManifestStatus, + GCSettings, + KillResult, + KillResultStatus, + GCResultState, + GCResultSet, + TaskRow, + PSResult, RemoteExecutionResult, + RemoteRunResult, + RemoteCompileResult, RemoteCatalogResults, + RemoteEmptyResult, + PollResult, + PollInProgressResult, + PollKilledResult, + PollExecuteCompleteResult, + PollRunCompleteResult, + PollCompileCompleteResult, + PollCatalogCompleteResult, + PollRemoteEmptyCompleteResult, ) -from dbt.logger import LogMessage +from dbt.logger import LogMessage, list_handler +from dbt.perf_utils import get_full_manifest from dbt.rpc.error import dbt_error, RPCException -from dbt.rpc.task_handler import TaskHandlerState, RequestTaskHandler -from dbt.rpc.method import RemoteMethod +from dbt.rpc.task_handler import ( + TaskHandlerState, RequestTaskHandler, set_parse_state_with +) +from dbt.rpc.method import RemoteMethod, RemoteManifestMethod, TaskList -from dbt.utils import restrict_to # import this to make sure our timedelta encoder is registered from dbt import helper_types # noqa +from dbt.utils import env_set_truthy + + +SINGLE_THREADED_WEBSERVER = env_set_truthy('DBT_SINGLE_THREADED_WEBSERVER') def _assert_started(task_handler: RequestTaskHandler) -> datetime: @@ -48,309 +73,203 @@ def _assert_ended(task_handler: RequestTaskHandler) -> datetime: return task_handler.ended -@dataclass -class TaskRow(JsonSchemaMixin): - task_id: uuid.UUID - request_id: Union[str, int] - request_source: str - method: str - state: TaskHandlerState - start: Optional[datetime] - end: Optional[datetime] - elapsed: Optional[float] - timeout: Optional[float] - tags: Optional[Dict[str, Any]] - - @classmethod - def from_task(cls, task_handler: RequestTaskHandler, now_time: datetime): - # get information about the task in a way that should not provide any - # conflicting information. 
Calculate elapsed time based on `now_time` - state = task_handler.state - # store end/start so 'ps' output always makes sense: - # not started -> no start time/elapsed, running -> no end time, etc - end = None - start = None - elapsed = None - if state > TaskHandlerState.NotStarted: - start = _assert_started(task_handler) - elapsed_end = now_time - - if state.finished: - elapsed_end = _assert_ended(task_handler) - end = elapsed_end - - elapsed = (elapsed_end - start).total_seconds() - - return cls( - task_id=task_handler.task_id, - request_id=task_handler.request_id, - request_source=task_handler.request_source, - method=task_handler.method, - state=state, - start=start, - end=end, - elapsed=elapsed, - timeout=task_handler.timeout, - tags=task_handler.tags, - ) - - -class KillResultStatus(StrEnum): - Missing = 'missing' - NotStarted = 'not_started' - Killed = 'killed' - Finished = 'finished' - - -@dataclass -class KillResult(JsonSchemaMixin): - status: KillResultStatus +def make_task(task_handler: RequestTaskHandler, now_time: datetime) -> TaskRow: + # get information about the task in a way that should not provide any + # conflicting information. Calculate elapsed time based on `now_time` + state = task_handler.state + # store end/start so 'ps' output always makes sense: + # not started -> no start time/elapsed, running -> no end time, etc + end = None + start = None + elapsed = None + if state > TaskHandlerState.NotStarted: + start = _assert_started(task_handler) + elapsed_end = now_time + + if state.finished: + elapsed_end = _assert_ended(task_handler) + end = elapsed_end + + elapsed = (elapsed_end - start).total_seconds() + + return TaskRow( + task_id=task_handler.task_id, + request_id=task_handler.request_id, + request_source=task_handler.request_source, + method=task_handler.method, + state=state, + start=start, + end=end, + elapsed=elapsed, + timeout=task_handler.timeout, + tags=task_handler.tags, + ) -@dataclass -class PollResult(JsonSchemaMixin): - tags: Optional[Dict[str, Any]] = None - status: TaskHandlerState = TaskHandlerState.NotStarted +UnmanagedHandler = Callable[..., JsonSchemaMixin] +WrappedHandler = Callable[..., Dict[str, Any]] -class GCResultState(StrEnum): - Deleted = 'deleted' # successful GC - Missing = 'missing' # nothing to GC - Running = 'running' # can't GC +def _wrap_builtin(func: UnmanagedHandler) -> WrappedHandler: + @wraps(func) + def inner(*args, **kwargs): + return func(*args, **kwargs).to_dict(omit_none=False) + return inner -@dataclass -class _GCResult(JsonSchemaMixin): - task_id: uuid.UUID - status: GCResultState +class Reserved: + # a dummy class + pass -@dataclass -class GCResultSet(JsonSchemaMixin): - deleted: List[uuid.UUID] = field(default_factory=list) - missing: List[uuid.UUID] = field(default_factory=list) - running: List[uuid.UUID] = field(default_factory=list) - - def add_result(self, result: _GCResult): - if result.status == GCResultState.Missing: - self.missing.append(result.task_id) - elif result.status == GCResultState.Running: - self.running.append(result.task_id) - elif result.status == GCResultState.Deleted: - self.deleted.append(result.task_id) - else: - raise dbt.exceptions.InternalException( - 'Got invalid _GCResult in add_result: {!r}' - .format(result) - ) +class ManifestReloader(threading.Thread): + def __init__(self, task_manager: 'TaskManager') -> None: + super().__init__() + self.task_manager = task_manager + def reload_manifest(self): + logs: List[LogMessage] = [] + with set_parse_state_with(self.task_manager, lambda: logs): + 
with list_handler(logs): + self.task_manager.parse_manifest() -@dataclass -class GCSettings(JsonSchemaMixin): - # start evicting the longest-ago-ended tasks here - maxsize: int - # start evicting all tasks before now - auto_reap_age when we have this - # many tasks in the table - reapsize: int - # a positive timedelta indicating how far back we should go - auto_reap_age: timedelta + def run(self) -> None: + try: + self.reload_manifest() + except Exception: + # ignore ugly thread-death error messages to stderr + pass @dataclass class _GCArguments(JsonSchemaMixin): + """An argument validation helper""" task_ids: Optional[List[uuid.UUID]] before: Optional[datetime] settings: Optional[GCSettings] -TaskTags = Optional[Dict[str, Any]] - - -@dataclass -class PollExecuteSuccessResult(PollResult, RemoteExecutionResult): - status: TaskHandlerState = field( - metadata=restrict_to(TaskHandlerState.Success), - default=TaskHandlerState.Success, - ) - - @classmethod - def from_result( - cls: Type['PollExecuteSuccessResult'], - status: TaskHandlerState, - base: RemoteExecutionResult, - tags: TaskTags, - ) -> 'PollExecuteSuccessResult': - return cls( - status=status, - results=base.results, - generated_at=base.generated_at, - elapsed_time=base.elapsed_time, - logs=base.logs, - tags=tags, - ) - - -@dataclass -class PollCompileSuccessResult(PollResult, RemoteCompileResult): - status: TaskHandlerState = field( - metadata=restrict_to(TaskHandlerState.Success), - default=TaskHandlerState.Success, - ) - - @classmethod - def from_result( - cls: Type['PollCompileSuccessResult'], - status: TaskHandlerState, - base: RemoteCompileResult, - tags: TaskTags, - ) -> 'PollCompileSuccessResult': - return cls( - status=status, - raw_sql=base.raw_sql, - compiled_sql=base.compiled_sql, - node=base.node, - timing=base.timing, - logs=base.logs, - tags=tags, - ) - - -@dataclass -class PollRunSuccessResult(PollResult, RemoteRunResult): - status: TaskHandlerState = field( - metadata=restrict_to(TaskHandlerState.Success), - default=TaskHandlerState.Success, - ) - - @classmethod - def from_result( - cls: Type['PollRunSuccessResult'], - status: TaskHandlerState, - base: RemoteRunResult, - tags: TaskTags, - ) -> 'PollRunSuccessResult': - return cls( - status=status, - raw_sql=base.raw_sql, - compiled_sql=base.compiled_sql, - node=base.node, - timing=base.timing, - logs=base.logs, - table=base.table, - tags=tags, - ) - - -@dataclass -class PollCatalogSuccessResult(PollResult, RemoteCatalogResults): - status: TaskHandlerState = field( - metadata=restrict_to(TaskHandlerState.Success), - default=TaskHandlerState.Success, - ) - - @classmethod - def from_result( - cls: Type['PollCatalogSuccessResult'], - status: TaskHandlerState, - base: RemoteCatalogResults, - tags: TaskTags, - ) -> 'PollCatalogSuccessResult': - return cls( - status=status, - nodes=base.nodes, - generated_at=base.generated_at, - _compile_results=base._compile_results, - logs=base.logs, - tags=tags, - ) - - -def poll_success( +def poll_complete( status: TaskHandlerState, result: Any, tags: TaskTags ) -> PollResult: - if status != TaskHandlerState.Success: + if status not in (TaskHandlerState.Success, TaskHandlerState.Failed): raise dbt.exceptions.InternalException( - 'got invalid result status in poll_success: {}'.format(status) + 'got invalid result status in poll_complete: {}'.format(status) ) + cls: Type[Union[ + PollExecuteCompleteResult, + PollRunCompleteResult, + PollCompileCompleteResult, + PollCatalogCompleteResult, + PollRemoteEmptyCompleteResult, + ]] + if 
isinstance(result, RemoteExecutionResult): - return PollExecuteSuccessResult.from_result(status, result, tags) + cls = PollExecuteCompleteResult # order matters here, as RemoteRunResult subclasses RemoteCompileResult elif isinstance(result, RemoteRunResult): - return PollRunSuccessResult.from_result(status, result, tags) + cls = PollRunCompleteResult elif isinstance(result, RemoteCompileResult): - return PollCompileSuccessResult.from_result(status, result, tags) + cls = PollCompileCompleteResult elif isinstance(result, RemoteCatalogResults): - return PollCatalogSuccessResult.from_result(status, result, tags) + cls = PollCatalogCompleteResult + elif isinstance(result, RemoteEmptyResult): + cls = PollRemoteEmptyCompleteResult else: raise dbt.exceptions.InternalException( - 'got invalid result in poll_success: {}'.format(result) + 'got invalid result in poll_complete: {}'.format(result) ) - - -@dataclass -class PollInProgressResult(PollResult): - logs: List[LogMessage] = field(default_factory=list) - - -@dataclass -class PSResult(JsonSchemaMixin): - rows: List[TaskRow] - - -class ManifestStatus(StrEnum): - Init = 'init' - Compiling = 'compiling' - Ready = 'ready' - Error = 'error' - - -@dataclass -class LastCompile(JsonSchemaMixin): - status: ManifestStatus - error: Optional[Dict[str, Any]] = None - logs: Optional[List[Dict[str, Any]]] = None - timestamp: datetime = field(default_factory=datetime.utcnow) - pid: int = field(default_factory=os.getpid) - - -UnmanagedHandler = Callable[..., JsonSchemaMixin] -WrappedHandler = Callable[..., Dict[str, Any]] - - -def _wrap_builtin(func: UnmanagedHandler) -> WrappedHandler: - @wraps(func) - def inner(*args, **kwargs): - return func(*args, **kwargs).to_dict(omit_none=False) - return inner + return cls.from_result(status, result, tags) class TaskManager: - def __init__(self, args, config): + def __init__(self, args, config, task_types: TaskList) -> None: self.args = args self.config = config - self.tasks: Dict[uuid.UUID, RequestTaskHandler] = {} - self._rpc_task_map: Dict[str, RemoteMethod] = {} + self._task_types: TaskList = task_types + self.active_tasks: Dict[uuid.UUID, RequestTaskHandler] = {} + self._rpc_task_map: Dict[str, Union[Reserved, RemoteMethod]] = {} self._builtins: Dict[str, UnmanagedHandler] = {} - self.last_compile = LastCompile(status=ManifestStatus.Init) + self.last_parse: LastParse = LastParse(status=ManifestStatus.Init) self._lock: dbt.flags.MP_CONTEXT.Lock = dbt.flags.MP_CONTEXT.Lock() - self._gc_settings = GCSettings( + self._gc_settings: GCSettings = GCSettings( maxsize=1000, reapsize=500, auto_reap_age=timedelta(days=30) ) + self._reloader: Optional[ManifestReloader] = None + + def _reload_task_manager_thread(self, reloader: ManifestReloader): + """This function can only be running once at a time, as it runs in the + signal handler we replace + """ + # compile in a thread that will fix up the tag manager when it's done + reloader.start() + # only assign to _reloader here, to avoid calling join() before start() + self._reloader = reloader + + def _reload_task_manager_fg(self, reloader: ManifestReloader): + """Override for single-threaded mode to run in the foreground""" + # just reload directly + reloader.reload_manifest() + + def reload_manifest_tasks(self) -> bool: + """Reload the manifest using a manifest reloader. Returns False if the + reload was not started because it was already running. 
+ """ + if not self.set_parsing(): + return False + if self._reloader is not None: + # join() the existing reloader + self._reloader.join() + # perform the reload + reloader = ManifestReloader(self) + if self.single_threaded(): + self._reload_task_manager_fg(reloader) + else: + self._reload_task_manager_thread(reloader) + return True - def add_request(self, request_handler): - self.tasks[request_handler.task_id] = request_handler + def single_threaded(self): + return SINGLE_THREADED_WEBSERVER or self.args.single_threaded - def reserve_handler(self, task): - self._rpc_task_map[task.METHOD_NAME] = None + def reload_non_manifest_tasks(self): + # reload all the non-manifest tasks because the config changed. + # manifest tasks are still blocked so we can ignore them + for task_cls in self._task_types.non_manifest(): + self.add_basic_task_handler(task_cls) - def _assert_unique_task(self, task_type: Type[RemoteMethod]): + def reload_config(self): + config = self.config.from_args(self.args) + self.config = config + # reload all the non-manifest tasks because the config changed. + # manifest tasks are still blocked so we can ignore them + self.reload_non_manifest_tasks() + return config + + def add_request(self, request_handler: RequestTaskHandler): + self.active_tasks[request_handler.task_id] = request_handler + + def reserve_handler(self, task: Type[RemoteMethod]) -> None: + """Reserved tasks will return a status indicating that the manifest is + compiling. + """ + if task.METHOD_NAME is None: + raise dbt.exceptions.InternalException( + f'Cannot add task {task} as it has no method name' + ) + self._rpc_task_map[task.METHOD_NAME] = Reserved() + + def _check_task_handler(self, task_type: Type[RemoteMethod]) -> None: + if task_type.METHOD_NAME is None: + raise dbt.exceptions.InternalException( + 'Task {} has no method name, cannot add it'.format(task_type) + ) method = task_type.METHOD_NAME if method not in self._rpc_task_map: # this is weird, but hey whatever return other_task = self._rpc_task_map[method] - if other_task is None or type(other_task) is task_type: + if isinstance(other_task, Reserved) or type(other_task) is task_type: return raise dbt.exceptions.InternalException( 'Got two tasks with the same method name! 
{0} and {1} both ' @@ -358,51 +277,72 @@ def _assert_unique_task(self, task_type: Type[RemoteMethod]): 'should be unique'.format(task_type, other_task) ) - def add_task_handler(self, task: Type[RemoteMethod], manifest: Manifest): - if task.METHOD_NAME is None: - raise dbt.exceptions.InternalException( - 'Task {} has no method name, cannot add it'.format(task) - ) - self._assert_unique_task(task) + def add_manifest_task_handler( + self, task: Type[RemoteManifestMethod], manifest: Manifest + ) -> None: + self._check_task_handler(task) assert task.METHOD_NAME is not None self._rpc_task_map[task.METHOD_NAME] = task( self.args, self.config, manifest ) - def rpc_task(self, method_name): + def add_basic_task_handler(self, task: Type[RemoteMethod]) -> None: + if issubclass(task, RemoteManifestMethod): + raise dbt.exceptions.InternalException( + f'Task {task} requires a manifest, cannot add it as a basic ' + f'handler' + ) + + self._check_task_handler(task) + assert task.METHOD_NAME is not None + self._rpc_task_map[task.METHOD_NAME] = task(self.args, self.config) + + def rpc_task(self, method_name: str) -> Union[Reserved, RemoteMethod]: with self._lock: return self._rpc_task_map[method_name] - def ready(self): + def ready(self) -> bool: with self._lock: - return self.last_compile.status == ManifestStatus.Ready + return self.last_parse.status == ManifestStatus.Ready - def set_compiling(self): - assert self.last_compile.status != ManifestStatus.Compiling, \ - f'invalid state {self.last_compile.status}' + def set_parsing(self) -> bool: with self._lock: - self.last_compile = LastCompile(status=ManifestStatus.Compiling) + if self.last_parse.status == ManifestStatus.Compiling: + return False + self.last_parse = LastParse(status=ManifestStatus.Compiling) + for task in self._task_types.manifest(): + # reserve any tasks that are invalid + self.reserve_handler(task) + return True + + def parse_manifest(self) -> None: + manifest = get_full_manifest(self.config) + + for task_cls in self._task_types.manifest(): + self.add_manifest_task_handler( + task_cls, manifest + ) - def set_compile_exception(self, exc, logs=List[Dict[str, Any]]): - assert self.last_compile.status == ManifestStatus.Compiling, \ - f'invalid state {self.last_compile.status}' - self.last_compile = LastCompile( + def set_compile_exception(self, exc, logs=List[Dict[str, Any]]) -> None: + assert self.last_parse.status == ManifestStatus.Compiling, \ + f'invalid state {self.last_parse.status}' + self.last_parse = LastParse( error={'message': str(exc)}, status=ManifestStatus.Error, logs=logs ) def set_ready(self, logs=List[Dict[str, Any]]) -> None: - assert self.last_compile.status == ManifestStatus.Compiling, \ - f'invalid state {self.last_compile.status}' - self.last_compile = LastCompile( + assert self.last_parse.status == ManifestStatus.Compiling, \ + f'invalid state {self.last_parse.status}' + self.last_parse = LastParse( status=ManifestStatus.Ready, logs=logs ) - def process_status(self) -> LastCompile: + def process_status(self) -> LastParse: with self._lock: - last_compile = self.last_compile + last_compile = self.last_parse return last_compile def process_ps( @@ -413,8 +353,8 @@ def process_ps( rows = [] now = datetime.utcnow() with self._lock: - for task in self.tasks.values(): - row = TaskRow.from_task(task, now) + for task in self.active_tasks.values(): + row = make_task(task, now) if row.state.finished and completed: rows.append(row) elif not row.state.finished and active: @@ -429,7 +369,7 @@ def process_kill(self, task_id: str) -> 
KillResult: status = KillResultStatus.Missing try: - task = self.tasks[task_id_uuid] + task: RequestTaskHandler = self.active_tasks[task_id_uuid] except KeyError: # nothing to do! return KillResult(status) @@ -443,10 +383,13 @@ def process_kill(self, task_id: str) -> KillResult: return KillResult(status) if task.process.is_alive(): - os.kill(pid, signal.SIGINT) status = KillResultStatus.Killed + task.ended = datetime.utcnow() + os.kill(pid, signal.SIGINT) + task.state = TaskHandlerState.Killed else: status = KillResultStatus.Finished + # the status must be "Completed" return KillResult(status) @@ -458,7 +401,7 @@ def process_poll( ) -> PollResult: task_id = uuid.UUID(request_token) try: - task: RequestTaskHandler = self.tasks[task_id] + task: RequestTaskHandler = self.active_tasks[task_id] except KeyError: # We don't recognize that ID. raise dbt.exceptions.UnknownAsyncIDException(task_id) from None @@ -472,12 +415,17 @@ def process_poll( # "forward-compatible" so if the state has transitioned to error/result # but we aren't there yet, the logs will still be valid. state = task.state - if state == TaskHandlerState.Error: + if state <= TaskHandlerState.Running: + return PollInProgressResult( + status=state, + tags=task.tags, + logs=task_logs, + ) + elif state == TaskHandlerState.Error: err = task.error if err is None: exc = dbt.exceptions.InternalException( - 'At end of task {}, state={} but error is None' - .format(state, task_id) + f'At end of task {task_id}, error state but error is None' ) raise RPCException.from_error( dbt_error(exc, logs=[l.to_dict() for l in task_logs]) @@ -485,27 +433,32 @@ def process_poll( # the exception has logs already attached from the child, don't # overwrite those raise err - elif state == TaskHandlerState.Success: + elif state in (TaskHandlerState.Success, TaskHandlerState.Failed): + if task.result is None: exc = dbt.exceptions.InternalException( - 'At end of task {}, state={} but result is None' - .format(state, task_id) + f'At end of task {task_id}, state={state} but result is ' + 'None' ) raise RPCException.from_error( dbt_error(exc, logs=[l.to_dict() for l in task_logs]) ) - - return poll_success( + return poll_complete( status=state, result=task.result, tags=task.tags, ) - - return PollInProgressResult( - status=state, - tags=task.tags, - logs=task_logs, - ) + elif state == TaskHandlerState.Killed: + return PollKilledResult( + status=state, tags=task.tags, logs=task_logs + ) + else: + exc = dbt.exceptions.InternalException( + f'Got unknown value state={state} for task {task_id}' + ) + raise RPCException.from_error( + dbt_error(exc, logs=[l.to_dict() for l in task_logs]) + ) def _rpc_builtins(self) -> Dict[str, UnmanagedHandler]: if self._builtins: @@ -543,9 +496,14 @@ def currently_compiling(self, *args, **kwargs): def compilation_error(self, *args, **kwargs): """Raise an RPC exception to trigger the error handler.""" raise dbt_error( - dbt.exceptions.RPCLoadException(self.last_compile.error) + dbt.exceptions.RPCLoadException(self.last_parse.error) ) + def internal_error_for(self, msg) -> WrappedHandler: + def _error(*args, **kwargs): + raise dbt.exceptions.InternalException(msg) + return _wrap_builtin(_error) + def get_handler( self, method, http_request, json_rpc_request ) -> Optional[Union[WrappedHandler, RemoteMethod]]: @@ -558,29 +516,41 @@ def get_handler( return _wrap_builtin(_builtins[method]) elif method not in self._rpc_task_map: return None - # if we have no manifest we want to return an error about why - elif self.last_compile.status == 
ManifestStatus.Compiling: - return self.currently_compiling - elif self.last_compile.status == ManifestStatus.Error: - return self.compilation_error + + task = self.rpc_task(method) + # If the task we got back was reserved, it must be a task that requires + # a manifest and we don't have one. So we had better have a state of + # compiling or error. + + if isinstance(task, Reserved): + status = self.last_parse.status + if status == ManifestStatus.Compiling: + return self.currently_compiling + elif status == ManifestStatus.Error: + return self.compilation_error + else: + # if we got here, there is an error in dbt :( + return self.internal_error_for( + f'Got a None task for {method}, state is {status}' + ) else: - return self.rpc_task(method) + return task def _remove_task_if_finished(self, task_id: uuid.UUID) -> GCResultState: """Remove the task if it was finished. Raises a KeyError if the entry is removed during operation (so hold the lock). """ - if task_id not in self.tasks: + if task_id not in self.active_tasks: return GCResultState.Missing - task = self.tasks[task_id] + task = self.active_tasks[task_id] if not task.state.finished: return GCResultState.Running - del self.tasks[task_id] + del self.active_tasks[task_id] return GCResultState.Deleted - def _gc_task_id(self, task_id: uuid.UUID) -> _GCResult: + def _gc_task_id(self, result: GCResultSet, task_id: uuid.UUID) -> None: """To 'gc' a task ID, we just delete it from the tasks dict. You must hold the lock, as this mutates `tasks`. @@ -595,11 +565,11 @@ def _gc_task_id(self, task_id: uuid.UUID) -> _GCResult: .format(task_id) ) - return _GCResult(task_id=task_id, status=status) + return result.add_result(task_id=task_id, status=status) def _get_gc_before_list(self, when: datetime) -> List[uuid.UUID]: removals: List[uuid.UUID] = [] - for task in self.tasks.values(): + for task in self.active_tasks.values(): if not task.state.finished: continue elif task.ended is None: @@ -611,7 +581,7 @@ def _get_gc_before_list(self, when: datetime) -> List[uuid.UUID]: def _get_oldest_ended_list(self, num: int) -> List[uuid.UUID]: candidates: List[Tuple[datetime, uuid.UUID]] = [] - for task in self.tasks.values(): + for task in self.active_tasks.values(): if not task.state.finished: continue elif task.ended is None: @@ -626,8 +596,7 @@ def _gc_multiple_task_ids( ) -> GCResultSet: result = GCResultSet() for task_id in task_ids: - gc_result = self._gc_task_id(task_id) - result.add_result(gc_result) + self._gc_task_id(result, task_id) return result def gc_safe( @@ -648,7 +617,7 @@ def gc_safe( def _gc_as_required_unsafe(self) -> None: to_remove: List[uuid.UUID] = [] - num_tasks = len(self.tasks) + num_tasks = len(self.active_tasks) if num_tasks > self._gc_settings.maxsize: num = self._gc_settings.maxsize - num_tasks to_remove = self._get_oldest_ended_list(num) diff --git a/core/dbt/task/deps.py b/core/dbt/task/deps.py index d5af4efb099..80d57a3414a 100644 --- a/core/dbt/task/deps.py +++ b/core/dbt/task/deps.py @@ -1,555 +1,19 @@ -import abc -import hashlib -import os -import shutil -import tempfile -from dataclasses import dataclass, field -from typing import ( - Union, Dict, Optional, List, Type, Iterator, NoReturn, Generic, TypeVar, -) - import dbt.utils import dbt.deprecations import dbt.exceptions -from dbt import semver -from dbt.ui import printer + +from dbt.deps.base import downloads_directory +from dbt.deps.resolver import resolve_packages from dbt.logger import GLOBAL_LOGGER as logger -from dbt.clients import git, registry, system -from 
dbt.contracts.project import ProjectPackageMetadata, \ - RegistryPackageMetadata, \ - LocalPackage as LocalPackageContract, \ - GitPackage as GitPackageContract, \ - RegistryPackage as RegistryPackageContract -from dbt.exceptions import raise_dependency_error, package_version_not_found, \ - VersionsNotCompatibleException, DependencyException +from dbt.clients import system from dbt.task.base import ProjectOnlyTask -DOWNLOADS_PATH = None -REMOVE_DOWNLOADS = False -PIN_PACKAGE_URL = 'https://docs.getdbt.com/docs/package-management#section-specifying-package-versions' # noqa - - -def _initialize_downloads(): - global DOWNLOADS_PATH, REMOVE_DOWNLOADS - # the user might have set an environment variable. Set it to None, and do - # not remove it when finished. - if DOWNLOADS_PATH is None: - DOWNLOADS_PATH = os.getenv('DBT_DOWNLOADS_DIR') - REMOVE_DOWNLOADS = False - # if we are making a per-run temp directory, remove it at the end of - # successful runs - if DOWNLOADS_PATH is None: - DOWNLOADS_PATH = tempfile.mkdtemp(prefix='dbt-downloads-') - REMOVE_DOWNLOADS = True - - system.make_directory(DOWNLOADS_PATH) - logger.debug("Set downloads directory='{}'".format(DOWNLOADS_PATH)) - - -PackageContract = Union[LocalPackageContract, GitPackageContract, - RegistryPackageContract] - - -def md5sum(s: str): - return hashlib.md5(s.encode('latin-1')).hexdigest() - - -PackageContractType = TypeVar('PackageContractType', bound=PackageContract) - - -class BasePackage(metaclass=abc.ABCMeta): - @abc.abstractproperty - def name(self) -> str: - raise NotImplementedError - - def all_names(self) -> List[str]: - return [self.name] - - @abc.abstractmethod - def source_type(self) -> str: - raise NotImplementedError - - -class LocalPackageMixin: - def __init__(self, local: str) -> None: - super().__init__() - self.local = local - - @property - def name(self): - return self.local - - def source_type(self): - return 'local' - - -class GitPackageMixin: - def __init__(self, git: str) -> None: - super().__init__() - self.git = git - - @property - def name(self): - return self.git - - def source_type(self) -> str: - return 'git' - - -class RegistryPackageMixin: - def __init__(self, package: str) -> None: - super().__init__() - self.package = package - - @property - def name(self): - return self.package - - def source_type(self) -> str: - return 'hub' - - -class PinnedPackage(BasePackage): - def __init__(self) -> None: - if hasattr(self, '_cached_metadata'): - raise ValueError('already here') - self._cached_metadata: Optional[ProjectPackageMetadata] = None - - def __str__(self) -> str: - version = self.get_version() - if not version: - return self.name - - return '{}@{}'.format(self.name, version) - - @abc.abstractmethod - def get_version(self) -> Optional[str]: - raise NotImplementedError - - @abc.abstractmethod - def _fetch_metadata(self, project): - raise NotImplementedError - - @abc.abstractmethod - def install(self, project): - raise NotImplementedError - - @abc.abstractmethod - def nice_version_name(self): - raise NotImplementedError - - def fetch_metadata(self, project): - if not self._cached_metadata: - self._cached_metadata = self._fetch_metadata(project) - return self._cached_metadata - - def get_project_name(self, project): - metadata = self.fetch_metadata(project) - return metadata.name - - def get_installation_path(self, project): - dest_dirname = self.get_project_name(project) - return os.path.join(project.modules_path, dest_dirname) - - -class LocalPinnedPackage(LocalPackageMixin, PinnedPackage): - def 
__init__(self, local: str) -> None: - super().__init__(local) - - def get_version(self): - return None - - def nice_version_name(self): - return ''.format(self.local) - - def resolve_path(self, project): - return system.resolve_path_from_base( - self.local, - project.project_root, - ) - - def _fetch_metadata(self, project): - loaded = project.from_project_root(self.resolve_path(project), {}) - return ProjectPackageMetadata.from_project(loaded) - - def install(self, project): - src_path = self.resolve_path(project) - dest_path = self.get_installation_path(project) - - can_create_symlink = system.supports_symlinks() - - if system.path_exists(dest_path): - if not system.path_is_symlink(dest_path): - system.rmdir(dest_path) - else: - system.remove_file(dest_path) - - if can_create_symlink: - logger.debug(' Creating symlink to local dependency.') - system.make_symlink(src_path, dest_path) - - else: - logger.debug(' Symlinks are not available on this ' - 'OS, copying dependency.') - shutil.copytree(src_path, dest_path) - - -class GitPinnedPackage(GitPackageMixin, PinnedPackage): - def __init__( - self, git: str, revision: str, warn_unpinned: bool = True - ) -> None: - super().__init__(git) - self.revision = revision - self.warn_unpinned = warn_unpinned - self._checkout_name = md5sum(self.git) - - def get_version(self): - return self.revision - - def nice_version_name(self): - return 'revision {}'.format(self.revision) - - def _checkout(self): - """Performs a shallow clone of the repository into the downloads - directory. This function can be called repeatedly. If the project has - already been checked out at this version, it will be a no-op. Returns - the path to the checked out directory.""" - try: - dir_ = git.clone_and_checkout( - self.git, DOWNLOADS_PATH, branch=self.revision, - dirname=self._checkout_name - ) - except dbt.exceptions.ExecutableError as exc: - if exc.cmd and exc.cmd[0] == 'git': - logger.error( - 'Make sure git is installed on your machine. 
More ' - 'information: ' - 'https://docs.getdbt.com/docs/package-management' - ) - raise - return os.path.join(DOWNLOADS_PATH, dir_) - - def _fetch_metadata(self, project) -> ProjectPackageMetadata: - path = self._checkout() - if self.revision == 'master' and self.warn_unpinned: - dbt.exceptions.warn_or_error( - 'The git package "{}" is not pinned.\n\tThis can introduce ' - 'breaking changes into your project without warning!\n\nSee {}' - .format(self.git, PIN_PACKAGE_URL), - log_fmt=printer.yellow('WARNING: {}') - ) - loaded = project.from_project_root(path, {}) - return ProjectPackageMetadata.from_project(loaded) - - def install(self, project): - dest_path = self.get_installation_path(project) - if os.path.exists(dest_path): - if system.path_is_symlink(dest_path): - system.remove_file(dest_path) - else: - system.rmdir(dest_path) - - system.move(self._checkout(), dest_path) - - -class RegistryPinnedPackage(RegistryPackageMixin, PinnedPackage): - def __init__(self, package: str, version: str) -> None: - super().__init__(package) - self.version = version - - @property - def name(self): - return self.package - - def source_type(self): - return 'hub' - - def get_version(self): - return self.version - - def nice_version_name(self): - return 'version {}'.format(self.version) - - def _fetch_metadata(self, project) -> RegistryPackageMetadata: - dct = registry.package_version(self.package, self.version) - return RegistryPackageMetadata.from_dict(dct) - - def install(self, project): - metadata = self.fetch_metadata(project) - - tar_name = '{}.{}.tar.gz'.format(self.package, self.version) - tar_path = os.path.realpath(os.path.join(DOWNLOADS_PATH, tar_name)) - system.make_directory(os.path.dirname(tar_path)) - - download_url = metadata.downloads.tarball - system.download(download_url, tar_path) - deps_path = project.modules_path - package_name = self.get_project_name(project) - system.untar_package(tar_path, deps_path, package_name) - - -SomePinned = TypeVar('SomePinned', bound=PinnedPackage) -SomeUnpinned = TypeVar('SomeUnpinned', bound='UnpinnedPackage') - - -class UnpinnedPackage(Generic[SomePinned], BasePackage): - @abc.abstractclassmethod - def from_contract(cls, contract): - raise NotImplementedError - - @abc.abstractmethod - def incorporate(self: SomeUnpinned, other: SomeUnpinned) -> SomeUnpinned: - raise NotImplementedError - - @abc.abstractmethod - def resolved(self) -> SomePinned: - raise NotImplementedError - - -class LocalUnpinnedPackage( - LocalPackageMixin, UnpinnedPackage[LocalPinnedPackage] -): - @classmethod - def from_contract( - cls, contract: LocalPackageContract - ) -> 'LocalUnpinnedPackage': - return cls(local=contract.local) - - def incorporate( - self, other: 'LocalUnpinnedPackage' - ) -> 'LocalUnpinnedPackage': - return LocalUnpinnedPackage(local=self.local) - - def resolved(self) -> LocalPinnedPackage: - return LocalPinnedPackage(local=self.local) - - -class GitUnpinnedPackage(GitPackageMixin, UnpinnedPackage[GitPinnedPackage]): - def __init__( - self, git: str, revisions: List[str], warn_unpinned: bool = True - ) -> None: - super().__init__(git) - self.revisions = revisions - self.warn_unpinned = warn_unpinned - - @classmethod - def from_contract( - cls, contract: GitPackageContract - ) -> 'GitUnpinnedPackage': - revisions = [contract.revision] if contract.revision else [] - - # we want to map None -> True - warn_unpinned = contract.warn_unpinned is not False - return cls(git=contract.git, revisions=revisions, - warn_unpinned=warn_unpinned) - - def all_names(self) -> 
List[str]: - if self.git.endswith('.git'): - other = self.git[:-4] - else: - other = self.git + '.git' - return [self.git, other] - - def incorporate( - self, other: 'GitUnpinnedPackage' - ) -> 'GitUnpinnedPackage': - warn_unpinned = self.warn_unpinned and other.warn_unpinned - - return GitUnpinnedPackage( - git=self.git, - revisions=self.revisions + other.revisions, - warn_unpinned=warn_unpinned, - ) - - def resolved(self) -> GitPinnedPackage: - requested = set(self.revisions) - if len(requested) == 0: - requested = {'master'} - elif len(requested) > 1: - dbt.exceptions.raise_dependency_error( - 'git dependencies should contain exactly one version. ' - '{} contains: {}'.format(self.git, requested)) - - return GitPinnedPackage( - git=self.git, revision=requested.pop(), - warn_unpinned=self.warn_unpinned - ) - - -class RegistryUnpinnedPackage( - RegistryPackageMixin, UnpinnedPackage[RegistryPinnedPackage] -): - def __init__( - self, package: str, versions: List[semver.VersionSpecifier] - ) -> None: - super().__init__(package) - self.versions = versions - - def _check_in_index(self): - index = registry.index_cached() - if self.package not in index: - dbt.exceptions.package_not_found(self.package) - - @classmethod - def from_contract( - cls, contract: RegistryPackageContract - ) -> 'RegistryUnpinnedPackage': - raw_version = contract.version - if isinstance(raw_version, str): - raw_version = [raw_version] - - versions = [ - semver.VersionSpecifier.from_version_string(v) - for v in raw_version - ] - return cls(package=contract.package, versions=versions) - - def incorporate( - self, other: 'RegistryUnpinnedPackage' - ) -> 'RegistryUnpinnedPackage': - return RegistryUnpinnedPackage( - package=self.package, - versions=self.versions + other.versions, - ) - - def resolved(self) -> RegistryPinnedPackage: - self._check_in_index() - try: - range_ = semver.reduce_versions(*self.versions) - except VersionsNotCompatibleException as e: - new_msg = ('Version error for package {}: {}' - .format(self.name, e)) - raise DependencyException(new_msg) from e - - available = registry.get_available_versions(self.package) - - # for now, pick a version and then recurse. later on, - # we'll probably want to traverse multiple options - # so we can match packages. not going to make a difference - # right now. 
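# [editor's note] Illustrative sketch, not part of this diff: the general idea
# behind the registry resolution above -- intersect the requested version
# specifiers, then pick the newest available version that satisfies all of
# them. This is a simplified stand-in, not dbt's semver/registry API.
import operator
from typing import List, Tuple

_OPS = {'>=': operator.ge, '<=': operator.le, '>': operator.gt,
        '<': operator.lt, '==': operator.eq}


def _parse(version: str) -> Tuple[int, ...]:
    return tuple(int(part) for part in version.split('.'))


def _satisfies(version: str, spec: str) -> bool:
    for prefix in ('>=', '<=', '==', '>', '<'):
        if spec.startswith(prefix):
            return _OPS[prefix](_parse(version), _parse(spec[len(prefix):]))
    return _parse(version) == _parse(spec)  # a bare version means exact match


def resolve(specs: List[str], available: List[str]) -> str:
    matching = [v for v in available if all(_satisfies(v, s) for s in specs)]
    if not matching:
        raise ValueError(f'no available version satisfies {specs}')
    return max(matching, key=_parse)


# e.g. resolve(['>=0.2.0', '<0.3.0'], ['0.1.9', '0.2.1', '0.3.0']) -> '0.2.1'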
- target = semver.resolve_to_specific_version(range_, available) - if not target: - package_version_not_found(self.package, range_, available) - return RegistryPinnedPackage(package=self.package, version=target) - - -@dataclass -class PackageListing: - packages: Dict[str, UnpinnedPackage] = field(default_factory=dict) - - def __len__(self): - return len(self.packages) - - def __bool__(self): - return bool(self.packages) - - def _pick_key(self, key: BasePackage) -> str: - for name in key.all_names(): - if name in self.packages: - return name - return key.name - - def __contains__(self, key: BasePackage): - for name in key.all_names(): - if name in self.packages: - return True - - def __getitem__(self, key: BasePackage): - key_str: str = self._pick_key(key) - return self.packages[key_str] - - def __setitem__(self, key: BasePackage, value): - key_str: str = self._pick_key(key) - self.packages[key_str] = value - - def _mismatched_types( - self, old: UnpinnedPackage, new: UnpinnedPackage - ) -> NoReturn: - raise_dependency_error( - f'Cannot incorporate {new} ({new.__class__.__name__}) in {old} ' - f'({old.__class__.__name__}): mismatched types' - ) - - def incorporate(self, package: UnpinnedPackage): - key: str = self._pick_key(package) - if key in self.packages: - existing: UnpinnedPackage = self.packages[key] - if not isinstance(existing, type(package)): - self._mismatched_types(existing, package) - self.packages[key] = existing.incorporate(package) - else: - self.packages[key] = package - - def update_from(self, src: List[PackageContract]) -> None: - pkg: UnpinnedPackage - for contract in src: - if isinstance(contract, LocalPackageContract): - pkg = LocalUnpinnedPackage.from_contract(contract) - elif isinstance(contract, GitPackageContract): - pkg = GitUnpinnedPackage.from_contract(contract) - elif isinstance(contract, RegistryPackageContract): - pkg = RegistryUnpinnedPackage.from_contract(contract) - else: - raise dbt.exceptions.InternalException( - 'Invalid package type {}'.format(type(contract)) - ) - self.incorporate(pkg) - - @classmethod - def from_contracts( - cls: Type['PackageListing'], src: List[PackageContract] - ) -> 'PackageListing': - self = cls({}) - self.update_from(src) - return self - - def resolved(self) -> List[PinnedPackage]: - return [p.resolved() for p in self.packages.values()] - - def __iter__(self) -> Iterator[UnpinnedPackage]: - return iter(self.packages.values()) - - -def resolve_packages( - packages: List[PackageContract], config -) -> List[PinnedPackage]: - pending = PackageListing.from_contracts(packages) - final = PackageListing() - - while pending: - next_pending = PackageListing() - # resolve the dependency in question - for package in pending: - final.incorporate(package) - target = final[package].resolved().fetch_metadata(config) - next_pending.update_from(target.packages) - pending = next_pending - return final.resolved() - class DepsTask(ProjectOnlyTask): def __init__(self, args, config=None): super().__init__(args=args, config=config) - self._downloads_path = None - - @property - def downloads_path(self): - if self._downloads_path is None: - self._downloads_path = tempfile.mkdtemp(prefix='dbt-downloads') - return self._downloads_path - - def _check_for_duplicate_project_names(self, final_deps): - seen = set() - for package in final_deps: - project_name = package.get_project_name(self.config) - if project_name in seen: - dbt.exceptions.raise_dependency_error( - 'Found duplicate project {}. 
This occurs when a dependency' - ' has the same project name as some other dependency.' - .format(project_name)) - seen.add(project_name) def track_package_install(self, package_name, source_type, version): version = 'local' if source_type == 'local' else version @@ -565,26 +29,21 @@ def track_package_install(self, package_name, source_type, version): def run(self): system.make_directory(self.config.modules_path) - _initialize_downloads() - packages = self.config.packages.packages if not packages: logger.info('Warning: No packages were found in packages.yml') return - final_deps = resolve_packages(packages, self.config) - - self._check_for_duplicate_project_names(final_deps) - - for package in final_deps: - logger.info('Installing {}', package) - package.install(self.config) - logger.info(' Installed from {}\n', package.nice_version_name()) + with downloads_directory(): + final_deps = resolve_packages(packages, self.config) - self.track_package_install( - package_name=package.name, - source_type=package.source_type(), - version=package.get_version()) + for package in final_deps: + logger.info('Installing {}', package) + package.install(self.config) + logger.info(' Installed from {}\n', + package.nice_version_name()) - if REMOVE_DOWNLOADS: - system.rmtree(DOWNLOADS_PATH) + self.track_package_install( + package_name=package.name, + source_type=package.source_type(), + version=package.get_version()) diff --git a/core/dbt/task/rpc/base.py b/core/dbt/task/rpc/base.py index 7187724d53b..7f8e425d264 100644 --- a/core/dbt/task/rpc/base.py +++ b/core/dbt/task/rpc/base.py @@ -1,15 +1,19 @@ from dbt.contracts.rpc import RemoteExecutionResult from dbt.task.runnable import GraphRunnableTask -from dbt.rpc.method import RemoteMethod, Parameters +from dbt.rpc.method import RemoteManifestMethod, Parameters class RPCTask( GraphRunnableTask, - RemoteMethod[Parameters, RemoteExecutionResult] + RemoteManifestMethod[Parameters, RemoteExecutionResult] ): def __init__(self, args, config, manifest): super().__init__(args, config) - RemoteMethod.__init__(self, args, config, manifest) + RemoteManifestMethod.__init__(self, args, config, manifest) + + def load_manifest(self): + # we started out with a manifest! 
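# [editor's note] Illustrative sketch, not part of this diff: the pattern used
# here -- a task that is handed an already-parsed manifest overrides the
# load_manifest() hook to a no-op so the runnable-task machinery does not
# re-parse the project. The class names below are hypothetical stand-ins.
class BaseRunnableTask:
    def __init__(self):
        self.manifest = None

    def load_manifest(self):
        # in the real code this would parse the whole project
        self.manifest = {'nodes': {}}

    def run(self):
        self.load_manifest()
        return self.manifest


class ManifestInjectedTask(BaseRunnableTask):
    def __init__(self, manifest):
        super().__init__()
        self.manifest = manifest

    def load_manifest(self):
        # we started out with a manifest, so there is nothing to load
        pass


task = ManifestInjectedTask({'nodes': {'model.a': object()}})
assert 'model.a' in task.run()['nodes']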
+ pass def get_result( self, results, elapsed_time, generated_at diff --git a/core/dbt/task/rpc/cli.py b/core/dbt/task/rpc/cli.py new file mode 100644 index 00000000000..d7d176be4bd --- /dev/null +++ b/core/dbt/task/rpc/cli.py @@ -0,0 +1,90 @@ +import abc +import shlex +from typing import Type, Optional + + +from dbt.contracts.rpc import RPCCliParameters + +from dbt.rpc.method import ( + RemoteMethod, + RemoteManifestMethod, + Parameters, + Result, +) +from dbt.exceptions import InternalException + +from .base import RPCTask + + +class HasCLI(RemoteMethod[Parameters, Result]): + @classmethod + def has_cli_parameters(cls): + return True + + @abc.abstractmethod + def handle_request(self) -> Result: + pass + + +class RemoteRPCCli(RPCTask[RPCCliParameters]): + METHOD_NAME = 'cli_args' + + def __init__(self, args, config, manifest): + super().__init__(args, config, manifest) + self.task_type: Optional[Type[RemoteMethod]] = None + self.real_task: Optional[RemoteMethod] = None + + def set_config(self, config): + super().set_config(config) + if issubclass(self.task_type, RemoteManifestMethod): + self.real_task = self.task_type( + self.args, self.config, self.manifest + ) + else: + self.real_task = self.task_type( + self.args, self.config + ) + + def set_args(self, params: RPCCliParameters) -> None: + # more import cycles :( + from dbt.main import parse_args, RPCArgumentParser + split = shlex.split(params.cli) + self.args = parse_args(split, RPCArgumentParser) + self.task_type = self.get_rpc_task_cls() + + def get_flags(self): + return self.task_type.get_flags(self) + + def get_rpc_task_cls(self) -> Type[HasCLI]: + # This is obnoxious, but we don't have actual access to the TaskManager + # so instead we get to dig through all the subclasses of RPCTask + # (recursively!) looking for a matching METHOD_NAME + candidate: Type[HasCLI] + for candidate in HasCLI.recursive_subclasses(): + if candidate.METHOD_NAME == self.args.rpc_method: + return candidate + # this shouldn't happen + raise InternalException( + 'No matching handler found for rpc method {} (which={})' + .format(self.args.rpc_method, self.args.which) + ) + + def load_manifest(self): + # we started out with a manifest! + pass + + def handle_request(self) -> Result: + if self.real_task is None: + raise InternalException( + 'CLI task is in a bad state: handle_request called with no ' + 'real_task set!' 
+ ) + # we parsed args from the cli, so we're set on that front + return self.real_task.handle_request() + + def interpret_results(self, results): + if self.real_task is None: + # I don't know what happened, but it was surely some flavor of + # failure + return False + return self.real_task.interpret_results(results) diff --git a/core/dbt/task/rpc/deps.py b/core/dbt/task/rpc/deps.py new file mode 100644 index 00000000000..b6506434a45 --- /dev/null +++ b/core/dbt/task/rpc/deps.py @@ -0,0 +1,36 @@ +import os +import shutil + +from dbt.contracts.rpc import ( + RPCNoParameters, RemoteEmptyResult, RemoteMethodFlags, +) +from dbt.rpc.method import RemoteMethod +from dbt.task.deps import DepsTask + + +def _clean_deps(config): + modules_dir = os.path.join(config.project_root, config.modules_path) + if os.path.exists(modules_dir): + shutil.rmtree(modules_dir) + os.makedirs(modules_dir) + + +class RemoteDepsTask( + RemoteMethod[RPCNoParameters, RemoteEmptyResult], + DepsTask, +): + METHOD_NAME = 'deps' + + def get_flags(self) -> RemoteMethodFlags: + return ( + RemoteMethodFlags.RequiresConfigReloadBefore | + RemoteMethodFlags.RequiresManifestReloadAfter + ) + + def set_args(self, params: RPCNoParameters): + pass + + def handle_request(self) -> RemoteEmptyResult: + _clean_deps(self.config) + self.run() + return RemoteEmptyResult([]) diff --git a/core/dbt/task/rpc/project_commands.py b/core/dbt/task/rpc/project_commands.py index b8ad3d688d4..93793c7f2ed 100644 --- a/core/dbt/task/rpc/project_commands.py +++ b/core/dbt/task/rpc/project_commands.py @@ -1,10 +1,8 @@ from datetime import datetime -import shlex -from typing import Type, List, Optional, Union +from typing import List, Optional, Union from dbt.contracts.rpc import ( - RPCCliParameters, RPCCompileParameters, RPCDocsGenerateParameters, RPCSeedParameters, @@ -12,17 +10,23 @@ RemoteCatalogResults, RemoteExecutionResult, ) -from dbt.exceptions import InternalException +from dbt.rpc.method import ( + Parameters, +) from dbt.task.compile import CompileTask from dbt.task.generate import GenerateTask from dbt.task.run import RunTask from dbt.task.seed import SeedTask from dbt.task.test import TestTask -from .base import RPCTask, Parameters +from .base import RPCTask +from .cli import HasCLI -class RPCCommandTask(RPCTask[Parameters]): +class RPCCommandTask( + RPCTask[Parameters], + HasCLI[Parameters, RemoteExecutionResult], +): @staticmethod def _listify( value: Optional[Union[str, List[str]]] @@ -34,10 +38,6 @@ def _listify( else: return value - def load_manifest(self): - # we started out with a manifest! - pass - def handle_request(self) -> RemoteExecutionResult: return self.run() @@ -97,33 +97,3 @@ def get_catalog_results( _compile_results=compile_results, logs=[], ) - - -class RemoteRPCParameters(RPCCommandTask[RPCCliParameters]): - METHOD_NAME = 'cli_args' - - def set_args(self, params: RPCCliParameters) -> None: - # more import cycles :( - from dbt.main import parse_args, RPCArgumentParser - split = shlex.split(params.cli) - self.args = parse_args(split, RPCArgumentParser) - - def get_rpc_task_cls(self) -> Type[RPCCommandTask]: - # This is obnoxious, but we don't have actual access to the TaskManager - # so instead we get to dig through all the subclasses of RPCTask - # (recursively!) 
looking for a matching METHOD_NAME - candidate: Type[RPCCommandTask] - for candidate in RPCCommandTask.recursive_subclasses(): - if candidate.METHOD_NAME == self.args.rpc_method: - return candidate - # this shouldn't happen - raise InternalException( - 'No matching handler found for rpc method {} (which={})' - .format(self.args.rpc_method, self.args.which) - ) - - def handle_request(self) -> RemoteExecutionResult: - cls = self.get_rpc_task_cls() - # we parsed args from the cli, so we're set on that front - task = cls(self.args, self.config, self.manifest) - return task.handle_request() diff --git a/core/dbt/task/rpc/server.py b/core/dbt/task/rpc/server.py index 026b8b2283c..8a3a963f43e 100644 --- a/core/dbt/task/rpc/server.py +++ b/core/dbt/task/rpc/server.py @@ -1,12 +1,12 @@ # import these so we can find them from . import sql_commands # noqa from . import project_commands # noqa -from .base import RPCTask +from . import deps # noqa import json import os import signal -import threading from contextlib import contextmanager +from typing import Iterator, Optional from werkzeug.middleware.dispatcher import DispatcherMiddleware from werkzeug.wrappers import Request, Response @@ -16,18 +16,14 @@ from dbt.exceptions import RuntimeException from dbt.logger import ( GLOBAL_LOGGER as logger, - list_handler, log_manager, ) -from dbt.task.base import ConfiguredTask -from dbt.utils import ForgivingJSONEncoder, env_set_truthy +from dbt.rpc.logger import ServerContext, HTTPRequest, RPCResponse +from dbt.rpc.method import TaskList from dbt.rpc.response_manager import ResponseManager from dbt.rpc.task_manager import TaskManager -from dbt.rpc.logger import ServerContext, HTTPRequest, RPCResponse -from dbt.perf_utils import get_full_manifest - - -SINGLE_THREADED_WEBSERVER = env_set_truthy('DBT_SINGLE_THREADED_WEBSERVER') +from dbt.task.base import ConfiguredTask +from dbt.utils import ForgivingJSONEncoder # SIG_DFL ends up killing the process if multiple build up, but SIG_IGN just @@ -35,24 +31,8 @@ SIG_IGN = signal.SIG_IGN -def reload_manager(task_manager, tasks): - logs = [] - try: - with list_handler(logs): - manifest = get_full_manifest(task_manager.config) - - for cls in tasks: - task_manager.add_task_handler(cls, manifest) - except Exception as exc: - logs = [r.to_dict() for r in logs] - task_manager.set_compile_exception(exc, logs=logs) - else: - logs = [r.to_dict() for r in logs] - task_manager.set_ready(logs=logs) - - @contextmanager -def signhup_replace(): +def signhup_replace() -> Iterator[bool]: """A context manager. Replace the current sighup handler with SIG_IGN on entering, and (if the current handler was not SIG_IGN) replace it on leaving. 
This is meant to be used inside a sighup handler itself to @@ -93,16 +73,17 @@ def signhup_replace(): class RPCServerTask(ConfiguredTask): DEFAULT_LOG_FORMAT = 'json' - def __init__(self, args, config, tasks=None): + def __init__(self, args, config, tasks: Optional[TaskList] = None): if os.name == 'nt': raise RuntimeException( 'The dbt RPC server is not supported on windows' ) super().__init__(args, config) - self._tasks = tasks or self._default_tasks() - self.task_manager = TaskManager(self.args, self.config) - self._reloader = None - self._reload_task_manager() + self.task_manager = TaskManager( + self.args, self.config, TaskList(tasks) + ) + self.task_manager.reload_non_manifest_tasks() + self.task_manager.reload_manifest_tasks() signal.signal(signal.SIGHUP, self._sighup_handler) @classmethod @@ -113,39 +94,16 @@ def pre_init_hook(cls, args): else: log_manager.format_json() - def _reload_task_manager(self): - """This function can only be running once at a time, as it runs in the - signal handler we replace - """ - # mark the task manager invalid for task running - self.task_manager.set_compiling() - for task in self._tasks: - self.task_manager.reserve_handler(task) - # compile in a thread that will fix up the tag manager when it's done - reloader = threading.Thread( - target=reload_manager, - args=(self.task_manager, self._tasks), - ) - reloader.start() - # only assign to _reloader here, to avoid calling join() before start() - self._reloader = reloader - def _sighup_handler(self, signum, frame): with signhup_replace() as run_task_manger: if not run_task_manger: # a sighup handler is already active. return - if self._reloader is not None and self._reloader.is_alive(): - # a reloader is already active. - return - self._reload_task_manager() - - @staticmethod - def _default_tasks(): - return RPCTask.recursive_subclasses(named_only=True) + self.task_manager.reload_config() + self.task_manager.reload_manifest_tasks() def single_threaded(self): - return SINGLE_THREADED_WEBSERVER or self.args.single_threaded + return self.task_manager.single_threaded() def run_forever(self): host = self.args.host @@ -181,7 +139,12 @@ def run_forever(self): # metadata+state in a multiprocessing.Manager, adds polling the # manager to the request task handler and in general gets messy # fast. 
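# [editor's note] Illustrative sketch, not part of this diff: how the
# single-threaded decision feeds the webserver's threaded flag. The helper
# below only mimics env_set_truthy-style behaviour and is a stand-in, not
# dbt.utils; the surrounding code passes threaded=not single_threaded(...)
# to werkzeug's run_simple.
import os


def _env_truthy(name: str) -> bool:
    value = os.getenv(name)
    return value is not None and value.lower() not in ('0', 'false', '')


class FakeArgs:
    single_threaded = False


def single_threaded(args: FakeArgs) -> bool:
    # either the env var or the CLI flag forces single-threaded mode
    return _env_truthy('DBT_SINGLE_THREADED_WEBSERVER') or args.single_threaded


print(single_threaded(FakeArgs()))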
- run_simple(host, port, app, threaded=not self.single_threaded) + run_simple( + host, + port, + app, + threaded=not self.task_manager.single_threaded(), + ) def run(self): with ServerContext().applicationbound(): diff --git a/core/dbt/task/rpc/sql_commands.py b/core/dbt/task/rpc/sql_commands.py index 855a9d705d4..51580a3b5c4 100644 --- a/core/dbt/task/rpc/sql_commands.py +++ b/core/dbt/task/rpc/sql_commands.py @@ -181,6 +181,9 @@ def handle_request(self) -> RemoteExecutionResult: generated_at=ended, ) + def interpret_results(self, results): + return True + class RemoteCompileTask(RemoteRunSQLTask, CompileTask): METHOD_NAME = 'compile_sql' diff --git a/test/integration/048_rpc_test/deps_models/main.sql b/test/integration/048_rpc_test/deps_models/main.sql new file mode 100644 index 00000000000..8f554fcc37a --- /dev/null +++ b/test/integration/048_rpc_test/deps_models/main.sql @@ -0,0 +1,3 @@ +{{ dbt_utils.log_info('blah') }} + +select 1 as id diff --git a/test/integration/048_rpc_test/test_rpc.py b/test/integration/048_rpc_test/test_rpc.py index 127a5c91a44..2f01cccbe0a 100644 --- a/test/integration/048_rpc_test/test_rpc.py +++ b/test/integration/048_rpc_test/test_rpc.py @@ -1,6 +1,7 @@ import json import os import random +import shutil import signal import socket import time @@ -172,27 +173,41 @@ def build_query( def url(self): return 'http://localhost:{}/jsonrpc'.format(self._server.port) - def poll_for_result(self, request_token, request_id=1, timeout=60): + def poll_for_result(self, request_token, request_id=1, timeout=60, status='success', logs=None): start = time.time() + kwargs = { + 'request_token': request_token, + } + if logs is not None: + kwargs['logs'] = logs + while True: time.sleep(0.5) - response = self.query('poll', request_token=request_token, _test_request_id=request_id) + response = self.query('poll', _test_request_id=request_id, **kwargs) response_json = response.json() if 'error' in response_json: return response result = self.assertIsResult(response_json, request_id) self.assertIn('status', result) - if result['status'] == 'success': + if result['status'] == status: return response if timeout is not None: - self.assertGreater(timeout, (time.time() - start)) - - - def async_query(self, _method, _sql=None, _test_request_id=1, macros=None, **kwargs): + delta = (time.time() - start) + self.assertGreater( + timeout, delta, + 'At time {}, never saw {}.\nLast response: {}' + .format(delta, status, result) + ) + + def async_query(self, _method, _sql=None, _test_request_id=1, _poll_timeout=60, macros=None, **kwargs): response = self.query(_method, _sql, _test_request_id, macros, **kwargs).json() result = self.assertIsResult(response, _test_request_id) self.assertIn('request_token', result) - return self.poll_for_result(result['request_token'], _test_request_id) + return self.poll_for_result( + result['request_token'], + request_id=_test_request_id, + timeout=_poll_timeout, + ) def query(self, _method, _sql=None, _test_request_id=1, macros=None, **kwargs): built = self.build_query(_method, kwargs, _sql, _test_request_id, macros) @@ -305,15 +320,13 @@ def kill_and_assert(self, request_token, request_id): poll_id = 90891 - poll_response = self.poll_for_result(request_token, poll_id).json() - error = self.assertIsErrorWithCode(poll_response, 10009, poll_id) - self.assertEqual(error['message'], 'RPC process killed') - self.assertIn('data', error) - error_data = error['data'] - self.assertEqual(error_data['signum'], 2) - self.assertEqual(error_data['message'], 'RPC process killed by 
signal 2') - self.assertIn('logs', error_data) - return error_data + poll_response = self.poll_for_result( + request_token, request_id=poll_id, status='killed', logs=True + ).json() + + result = self.assertIsResult(poll_response, id_=poll_id) + self.assertIn('logs', result) + return result def get_sleep_query(self, duration=15, request_id=90890): sleep_query = self.query( @@ -327,23 +340,27 @@ def get_sleep_query(self, duration=15, request_id=90890): request_token = result['request_token'] return request_token, request_id - def wait_for_running(self, timeout=25, raise_on_timeout=True): + def wait_for_state( + self, state, timestamp, timeout=25, raise_on_timeout=True + ): started = time.time() time.sleep(0.5) elapsed = time.time() - started while elapsed < timeout: status = self.assertIsResult(self.query('status').json()) - if status['status'] == 'running': + self.assertTrue(status['timestamp'] >= timestamp) + if status['timestamp'] != timestamp and status['status'] == state: return status time.sleep(0.5) elapsed = time.time() - started status = self.assertIsResult(self.query('status').json()) + self.assertTrue(status['timestamp'] >= timestamp) if raise_on_timeout: self.assertEqual( status['status'], - 'ready', + state, f'exceeded max time of {timeout}: {elapsed} seconds elapsed' ) return status @@ -353,9 +370,11 @@ def run_command_with_id(self, cmd, id_): def make_many_requests(self, num_requests): stored = [] - for idx in range(num_requests): - response = self.query('run_sql', 'select 1 as id', name='run').json() - result = self.assertIsResult(response) + for idx in range(1, num_requests+1): + response = self.query( + 'run_sql', 'select 1 as id', name='run', _test_request_id=idx + ).json() + result = self.assertIsResult(response, id_=idx) self.assertIn('request_token', result) token = result['request_token'] self.poll_for_result(token) @@ -654,7 +673,7 @@ def test_ps_kill_postgres(self): self.assertEqual(rowdict[0]['tags'], task_tags) self.assertEqual(rowdict[1]['request_id'], request_id) self.assertEqual(rowdict[1]['method'], 'run_sql') - self.assertEqual(rowdict[1]['state'], 'error') + self.assertEqual(rowdict[1]['state'], 'killed') self.assertIsNone(rowdict[1]['timeout']) self.assertGreater(rowdict[1]['elapsed'], 0) self.assertIsNone(rowdict[1]['tags']) @@ -669,8 +688,8 @@ def test_ps_kill_longwait_postgres(self): # we cancel the in-progress sleep query. 
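# [editor's note] Illustrative sketch, not part of this diff: the polling
# pattern these tests rely on -- keep calling `poll` until the reported status
# matches the one we expect (e.g. 'killed') or a timeout elapses. The `query`
# callable is a hypothetical stand-in for the JSON-RPC client helper.
import time
from typing import Any, Callable, Dict


def poll_until(
    query: Callable[..., Dict[str, Any]],
    request_token: str,
    status: str = 'success',
    timeout: float = 60.0,
    interval: float = 0.5,
) -> Dict[str, Any]:
    start = time.time()
    while True:
        result = query('poll', request_token=request_token, logs=True)
        if result.get('status') == status:
            return result
        if time.time() - start > timeout:
            raise TimeoutError(
                f'never saw status={status!r}; last response: {result}'
            )
        time.sleep(interval)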
time.sleep(3) - error_data = self.kill_and_assert(request_token, request_id) - self.assertTrue(len(error_data['logs']) > 0) + result_data = self.kill_and_assert(request_token, request_id) + self.assertTrue(len(result_data['logs']) > 0) @use_profile('postgres') def test_invalid_requests_postgres(self): @@ -807,7 +826,6 @@ def test_seed_project_cli_postgres(self): @use_profile('postgres') def test_compile_project_postgres(self): - self.run_dbt_with_vars(['seed']) result = self.async_query('compile').json() self.assertHasResults( @@ -827,7 +845,7 @@ def test_compile_project_postgres(self): @use_profile('postgres') def test_compile_project_cli_postgres(self): - self.run_dbt_with_vars(['seed']) + self.run_dbt_with_vars(['compile']) result = self.async_query('cli_args', cli='compile').json() self.assertHasResults( result, @@ -846,21 +864,18 @@ def test_compile_project_cli_postgres(self): @use_profile('postgres') def test_run_project_postgres(self): - self.run_dbt_with_vars(['seed']) result = self.async_query('run').json() self.assertHasResults(result, {'descendant_model', 'multi_source_model', 'nonsource_descendant'}) self.assertTablesEqual('multi_source_model', 'expected_multi_source') @use_profile('postgres') def test_run_project_cli_postgres(self): - self.run_dbt_with_vars(['seed']) result = self.async_query('cli_args', cli='run').json() self.assertHasResults(result, {'descendant_model', 'multi_source_model', 'nonsource_descendant'}) self.assertTablesEqual('multi_source_model', 'expected_multi_source') @use_profile('postgres') def test_test_project_postgres(self): - self.run_dbt_with_vars(['seed']) self.run_dbt_with_vars(['run']) data = self.async_query('test').json() result = self.assertIsResult(data) @@ -869,7 +884,6 @@ def test_test_project_postgres(self): @use_profile('postgres') def test_test_project_cli_postgres(self): - self.run_dbt_with_vars(['seed']) self.run_dbt_with_vars(['run']) data = self.async_query('cli_args', cli='test').json() result = self.assertIsResult(data) @@ -891,7 +905,6 @@ def assertHasDocsGenerated(self, result, expected): nodes = dct['nodes'] self.assertEqual(set(nodes), expected) - def assertCatalogExists(self): self.assertTrue(os.path.exists('target/catalog.json')) with open('target/catalog.json') as fp: @@ -914,20 +927,17 @@ def _correct_docs_generate_result(self, result): self.assertCatalogExists() self.assertManifestExists(17) - @use_profile('postgres') def test_docs_generate_postgres(self): - self.run_dbt_with_vars(['seed']) self.run_dbt_with_vars(['run']) self.assertFalse(os.path.exists('target/catalog.json')) if os.path.exists('target/manifest.json'): os.remove('target/manifest.json') - result = self.async_query('cli_args', cli='docs generate').json() + result = self.async_query('docs.generate').json() self._correct_docs_generate_result(result) @use_profile('postgres') def test_docs_generate_postgres_cli(self): - self.run_dbt_with_vars(['seed']) self.run_dbt_with_vars(['run']) self.assertFalse(os.path.exists('target/catalog.json')) if os.path.exists('target/manifest.json'): @@ -935,6 +945,15 @@ def test_docs_generate_postgres_cli(self): result = self.async_query('cli_args', cli='docs generate').json() self._correct_docs_generate_result(result) + @use_profile('postgres') + def test_deps_postgres(self): + self.async_query('deps').json() + + @mark.skip(reason='cli_args + deps not supported for now') + @use_profile('postgres') + def test_deps_postgres_cli(self): + self.async_query('cli_args', cli='deps').json() + @mark.flaky(rerun_filter=addr_in_use) class 
TestRPCTaskManagement(HasRPCServer): @@ -966,7 +985,7 @@ def test_sighup_postgres(self): for _ in range(10): os.kill(status['pid'], signal.SIGHUP) - status = self.wait_for_running() + self.wait_for_state('ready', timestamp=status['timestamp']) # we should still still see our service: self.assertRunning(sleepers) @@ -1038,83 +1057,74 @@ def test_gc_by_id_postgres(self): result = self.assertIsResult(resp) self.assertEqual(len(result['rows']), 0) - @use_profile('postgres') - def test_postgres_gc_change_interval(self): - num_requests = 10 - self.make_many_requests(num_requests) - - # all present - resp = self.query('ps', completed=True, active=True).json() - result = self.assertIsResult(resp) - self.assertEqual(len(result['rows']), num_requests) - resp = self.query('gc', settings=dict(maxsize=1000, reapsize=5, auto_reap_age=0.1)).json() - result = self.assertIsResult(resp) - self.assertEqual(len(result['deleted']), 0) - self.assertEqual(len(result['missing']), 0) - self.assertEqual(len(result['running']), 0) - time.sleep(0.5) - - # all cleared up - test_resp = self.query('ps', completed=True, active=True) - resp = test_resp.json() - result = self.assertIsResult(resp) - self.assertEqual(len(result['rows']), 0) - - resp = self.query('gc', settings=dict(maxsize=2, reapsize=5, auto_reap_age=10000)).json() - result = self.assertIsResult(resp) - self.assertEqual(len(result['deleted']), 0) - self.assertEqual(len(result['missing']), 0) - self.assertEqual(len(result['running']), 0) - - # make more requests - self.make_many_requests(num_requests) - time.sleep(0.5) - # there should be 2 left! - resp = self.query('ps', completed=True, active=True).json() - result = self.assertIsResult(resp) - self.assertEqual(len(result['rows']), 2) - - -class FailedServerProcess(ServerProcess): +class CompletingServerProcess(ServerProcess): def _compare_result(self, result): - return result['result']['status'] == 'error' + return result['result']['status'] in ('error', 'ready') @mark.flaky(rerun_filter=addr_in_use) -class TestRPCServerFailed(HasRPCServer): - ServerProcess = FailedServerProcess +class TestRPCServerDeps(HasRPCServer): + ServerProcess = CompletingServerProcess should_seed = False - @property - def models(self): - return "malformed_models" + def setUp(self): + super().setUp() + if os.path.exists('./dbt_modules'): + shutil.rmtree('./dbt_modules') def tearDown(self): - # prevent an OperationalError where the server closes on us in the - # background + if os.path.exists('./dbt_modules'): + shutil.rmtree('./dbt_modules') self.adapter.cleanup_connections() super().tearDown() - @use_profile('postgres') - def test_postgres_status_error(self): + @property + def packages_config(selF): + return { + 'packages': [ + {'package': 'fishtown-analytics/dbt_utils', 'version': '0.2.1'}, + ] + } + + @property + def models(self): + return "deps_models" + + def _check_start_predeps(self): + self.assertFalse(os.path.exists('./dbt_modules')) status = self.assertIsResult(self.query('status').json()) - self.assertEqual(status['status'], 'error') - self.assertIn('logs', status) - logs = status['logs'] - self.assertTrue(len(logs) > 0) - for key in ('message', 'timestamp', 'levelname', 'level'): - self.assertIn(key, logs[0]) - self.assertIn('pid', status) - self.assertEqual(self._server.pid, status['pid']) - self.assertIn('error', status) - self.assertIn('message', status['error']) - - compile_result = self.query('compile_sql', 'select 1 as id').json() - data = self.assertIsErrorWith( - compile_result, - 10011, - 'RPC server failed to 
compile project, call the "status" method for compile status', - None) - self.assertIn('message', data) - self.assertIn('Invalid test config', str(data['message'])) + self.assertEqual(status['status'], 'ready') + + self.assertIsError(self.async_query('compile').json()) + if os.path.exists('./dbt_modules'): + self.assertEqual(len(os.listdir('./dbt_modules')), 0) + return status + + def _check_deps_ok(self, status): + os.kill(status['pid'], signal.SIGHUP) + + self.wait_for_state('ready', timestamp=status['timestamp']) + + self.assertTrue(os.path.exists('./dbt_modules')) + self.assertEqual(len(os.listdir('./dbt_modules')), 1) + self.assertIsResult(self.async_query('compile').json()) + + @use_profile('postgres') + def test_deps_compilation_postgres(self): + status = self._check_start_predeps() + + # do a dbt deps, wait for the result + self.assertIsResult(self.async_query('deps', _poll_timeout=120).json()) + + self._check_deps_ok(status) + + @mark.skip(reason='cli_args + deps not supported for now') + @use_profile('postgres') + def test_deps_cli_compilation_postgres(self): + status = self._check_start_predeps() + + # do a dbt deps, wait for the result + self.assertIsResult(self.async_query('cli_args', cli='deps', _poll_timeout=120).json()) + + self._check_deps_ok(status) diff --git a/test/rpc/__init__.py b/test/rpc/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/rpc/fixtures.py b/test/rpc/fixtures.py new file mode 100644 index 00000000000..a58a9f4dfb7 --- /dev/null +++ b/test/rpc/fixtures.py @@ -0,0 +1,533 @@ +import base64 +import json +import os +import pytest +import random +import signal +import socket +import time +from contextlib import contextmanager +from typing import Dict, Any, Optional, Union, List + +import requests +import yaml + +import dbt.flags +from dbt.adapters.factory import get_adapter, register_adapter +from dbt.logger import log_manager +from dbt.main import handle_and_check +from dbt.config import RuntimeConfig + + +def query_url(url, query: Dict[str, Any]): + headers = {'content-type': 'application/json'} + return requests.post(url, headers=headers, data=json.dumps(query)) + + +class NoServerException(Exception): + pass + + +class ServerProcess(dbt.flags.MP_CONTEXT.Process): + def __init__(self, cwd, port, profiles_dir, cli_vars=None, criteria=('ready',)): + self.cwd = cwd + self.port = port + self.criteria = criteria + self.error = None + handle_and_check_args = [ + '--strict', 'rpc', '--log-cache-events', + '--port', str(self.port), + '--profiles-dir', profiles_dir + ] + if cli_vars: + handle_and_check_args.extend(['--vars', cli_vars]) + super().__init__( + target=handle_and_check, + args=(handle_and_check_args,), + name='ServerProcess') + + def run(self): + os.chdir(self.cwd) + log_manager.reset_handlers() + # run server tests in stderr mode + log_manager.stderr_console() + return super().run() + + def can_connect(self): + sock = socket.socket() + try: + sock.connect(('localhost', self.port)) + except socket.error: + return False + sock.close() + return True + + def _compare_result(self, result): + return result['result']['status'] in self.criteria + + def status_ok(self): + result = self.query( + {'method': 'status', 'id': 1, 'jsonrpc': '2.0'} + ).json() + return self._compare_result(result) + + def is_up(self): + if not self.can_connect(): + return False + return self.status_ok() + + def start(self): + super().start() + for _ in range(30): + if self.is_up(): + break + time.sleep(0.5) + if not self.can_connect(): + raise 
NoServerException('server never appeared!') + status_result = self.query( + {'method': 'status', 'id': 1, 'jsonrpc': '2.0'} + ).json() + if not self._compare_result(status_result): + raise NoServerException( + 'Got invalid status result: {}'.format(status_result) + ) + + @property + def url(self): + return 'http://localhost:{}/jsonrpc'.format(self.port) + + def query(self, query): + headers = {'content-type': 'application/json'} + return requests.post(self.url, headers=headers, data=json.dumps(query)) + + +class Querier: + def __init__(self, server: ServerProcess): + self.server = server + + def build_request_data(self, method, params, request_id): + return { + 'jsonrpc': '2.0', + 'method': method, + 'params': params, + 'id': request_id, + } + + def request(self, method, params=None, request_id=1): + if params is None: + params = {} + + data = self.build_request_data( + method=method, params=params, request_id=request_id + ) + response = self.server.query(data) + assert response.ok, f'invalid response from server: {response.text}' + return response.json() + + def status(self, request_id: int = 1): + return self.request(method='status', request_id=request_id) + + def ps(self, active=True, completed=False, request_id=1): + params = {} + if active is not None: + params['active'] = active + if completed is not None: + params['completed'] = completed + + return self.request(method='ps', params=params, request_id=request_id) + + def kill(self, task_id: str, request_id: int = 1): + params = {'task_id': task_id} + return self.request( + method='kill', params=params, request_id=request_id + ) + + def poll( + self, + request_token: str, + logs: bool = False, + logs_start: int = 0, + request_id: int = 1, + ): + params = { + 'request_token': request_token, + } + if logs is not None: + params['logs'] = logs + if logs_start is not None: + params['logs_start'] = logs_start + return self.request( + method='poll', params=params, request_id=request_id + ) + + def gc( + self, + task_ids: Optional[List[str]] = None, + before: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, + request_id: int = 1, + ): + params = {} + if task_ids is not None: + params['task_ids'] = task_ids + if before is not None: + params['before'] = before + if settings is not None: + params['settings'] = settings + return self.request( + method='gc', params=params, request_id=request_id + ) + + def cli_args(self, cli: str, request_id: int = 1): + return self.request( + method='cli_args', params={'cli': cli}, request_id=request_id + ) + + def deps(self, request_id: int = 1): + return self.request(method='deps', request_id=request_id) + + def compile( + self, + models: Optional[Union[str, List[str]]] = None, + exclude: Optional[Union[str, List[str]]] = None, + request_id: int = 1, + ): + params = {} + if models is not None: + params['models'] = models + if exclude is not None: + params['exclude'] = exclude + return self.request( + method='compile', params=params, request_id=request_id + ) + + def run( + self, + models: Optional[Union[str, List[str]]] = None, + exclude: Optional[Union[str, List[str]]] = None, + request_id: int = 1, + ): + params = {} + if models is not None: + params['models'] = models + if exclude is not None: + params['exclude'] = exclude + return self.request( + method='run', params=params, request_id=request_id + ) + + def seed(self, show: bool = None, request_id: int = 1): + params = {} + if show is not None: + params['show'] = show + return self.request( + method='seed', params=params, 
request_id=request_id + ) + + def test( + self, + models: Optional[Union[str, List[str]]] = None, + exclude: Optional[Union[str, List[str]]] = None, + data: bool = None, + schema: bool = None, + request_id: int = 1, + ): + params = {} + if models is not None: + params['models'] = models + if exclude is not None: + params['exclude'] = exclude + if data is not None: + params['data'] = data + if schema is not None: + params['schema'] = schema + return self.request( + method='test', params=params, request_id=request_id + ) + + def docs_generate(self, compile: bool = None, request_id: int = 1): + params = {} + if compile is not None: + params['compile'] = compile + return self.request( + method='docs.generate', params=params, request_id=request_id + ) + + def compile_sql( + self, + sql: str, + name: str = 'test_compile', + macros: Optional[str] = None, + request_id: int = 1, + ): + sql = base64.b64encode(sql.encode('utf-8')).decode('utf-8') + params = { + 'name': name, + 'sql': sql, + 'macros': macros, + } + return self.request( + method='compile_sql', params=params, request_id=request_id + ) + + def run_sql( + self, + sql: str, + name: str = 'test_run', + macros: Optional[str] = None, + request_id: int = 1, + ): + sql = base64.b64encode(sql.encode('utf-8')).decode('utf-8') + params = { + 'name': name, + 'sql': sql, + 'macros': macros, + } + return self.request( + method='run_sql', params=params, request_id=request_id + ) + + def is_result(self, data: Dict[str, Any], id=None) -> Dict[str, Any]: + if id is not None: + assert data['id'] == id + assert data['jsonrpc'] == '2.0' + assert 'result' in data + assert 'error' not in data + return data['result'] + + def is_async_result(self, data: Dict[str, Any], id=None) -> str: + result = self.is_result(data, id) + assert 'request_token' in result + return result['request_token'] + + def is_error(self, data: Dict[str, Any], id=None) -> Dict[str, Any]: + if id is not None: + assert data['id'] == id + assert data['jsonrpc'] == '2.0' + assert 'result' not in data + assert 'error' in data + return data['error'] + + def async_wait(self, token: str, timeout: int = 60, status='success') -> Dict[str, Any]: + start = time.time() + while True: + time.sleep(0.5) + response = self.poll(token) + if 'error' in response: + return response + result = self.is_result(response) + assert 'status' in result + if result['status'] == status: + return response + delta = (time.time() - start) + assert timeout > delta, \ + f'At time {delta}, never saw {status}.\nLast response: {result}' + + +def _first_server(cwd, cli_vars, profiles_dir, criteria): + stored = None + for _ in range(5): + port = random.randint(20000, 65535) + + proc = ServerProcess( + cwd=cwd, + cli_vars=cli_vars, + profiles_dir=str(profiles_dir), + port=port, + criteria=criteria, + ) + try: + proc.start() + except NoServerException as exc: + stored = exc + else: + return proc + if stored: + raise stored + + +@contextmanager +def rpc_server(project_dir, schema, profiles_dir, criteria='ready'): + if isinstance(criteria, str): + criteria = (criteria,) + else: + criteria = tuple(criteria) + + cli_vars = '{{test_run_schema: {}}}'.format(schema) + + proc = _first_server(project_dir, cli_vars, profiles_dir, criteria) + yield proc + if proc.is_alive(): + os.kill(proc.pid, signal.SIGKILL) + proc.join() + + +@pytest.fixture +def unique_schema() -> str: + return "test{}{:04}".format(int(time.time()), random.randint(0, 9999)) + + +@pytest.fixture +def profiles_dir(tmpdir): + return tmpdir.mkdir('profile') + + +@pytest.fixture 
+def postgres_profile_data(unique_schema): + return { + 'config': { + 'send_anonymous_usage_stats': False + }, + 'test': { + 'outputs': { + 'default': { + 'type': 'postgres', + 'threads': 4, + 'host': 'database', + 'port': 5432, + 'user': 'root', + 'pass': 'password', + 'dbname': 'dbt', + 'schema': unique_schema, + }, + }, + 'target': 'default' + } + } + + +@pytest.fixture +def postgres_profile(profiles_dir, postgres_profile_data) -> Dict[str, Any]: + path = os.path.join(profiles_dir, 'profiles.yml') + with open(path, 'w') as fp: + fp.write(yaml.safe_dump(postgres_profile_data)) + return postgres_profile_data + + +@pytest.fixture +def project_dir(tmpdir): + return tmpdir.mkdir('project') + + +class ProjectDefinition: + def __init__( + self, + name='test', + version='0.1.0', + profile='test', + project_data=None, + packages=None, + models=None, + macros=None, + ): + self.project = { + 'name': name, + 'version': version, + 'profile': profile, + } + if project_data: + self.project.update(project_data) + self.packages = packages + self.models = models + self.macros = macros + + def _write_recursive(self, path, inputs): + for name, value in inputs.items(): + if name.endswith('.sql'): + path.join(name).write(value) + elif name.endswith('.yml'): + if isinstance(value, str): + data = value + else: + data = yaml.safe_dump(value) + path.join(name).write(data) + else: + self._write_recursive(path.mkdir(name), value) + + def write_packages(self, project_dir, remove=False): + if remove: + project_dir.join('packages.yml').remove() + if self.packages is not None: + if isinstance(self.packages, str): + data = self.packages + else: + data = yaml.safe_dump(self.packages) + project_dir.join('packages.yml').write(data) + + def write_config(self, project_dir, remove=False): + cfg = project_dir.join('dbt_project.yml') + if remove: + cfg.remove() + cfg.write(yaml.safe_dump(self.project)) + + def write_models(self, project_dir, remove=False): + if remove: + project_dir.join('models').remove() + + if self.models is not None: + self._write_recursive(project_dir.mkdir('models'), self.models) + + def write_macros(self, project_dir, remove=False): + if remove: + project_dir.join('macros').remove() + + if self.macros is not None: + self._write_recursive(project_dir.mkdir('macros'), self.macros) + + def write_to(self, project_dir, remove=False): + if remove: + project_dir.remove() + project_dir.mkdir() + self.write_packages(project_dir) + self.write_config(project_dir) + self.write_models(project_dir) + self.write_macros(project_dir) + + +class TestArgs: + def __init__(self, profiles_dir, which='run-operation', kwargs={}): + self.which = which + self.single_threaded = False + self.profiles_dir = profiles_dir + self.profile = None + self.target = None + self.__dict__.update(kwargs) + + +def execute(adapter, sql): + with adapter.connection_named('rpc-tests') as conn: + with conn.handle.cursor() as cursor: + try: + cursor.execute(sql) + conn.handle.commit() + + except Exception as e: + if conn.handle and conn.handle.closed == 0: + conn.handle.rollback() + print(sql) + print(e) + raise + finally: + conn.transaction_open = False + + +@contextmanager +def built_schema(project_dir, schema, profiles_dir, test_kwargs, project_def): + # make our args, write our project out + args = TestArgs(profiles_dir=profiles_dir, kwargs=test_kwargs) + project_def.write_to(project_dir) + # build a config of our own + os.chdir(project_dir) + start = os.getcwd() + try: + cfg = RuntimeConfig.from_args(args) + finally: + os.chdir(start) + 
register_adapter(cfg) + adapter = get_adapter(cfg) + execute(adapter, 'drop schema if exists {} cascade'.format(schema)) + execute(adapter, 'create schema {}'.format(schema)) + yield + adapter = get_adapter(cfg) + adapter.cleanup_connections() + execute(adapter, 'drop schema if exists {} cascade'.format(schema)) diff --git a/test/rpc/test_base.py b/test/rpc/test_base.py new file mode 100644 index 00000000000..951dbc4edc2 --- /dev/null +++ b/test/rpc/test_base.py @@ -0,0 +1,271 @@ +import time +from .fixtures import ( + ProjectDefinition, rpc_server, Querier, project_dir, profiles_dir, + postgres_profile, unique_schema, postgres_profile_data, built_schema, +) + + +def test_rpc_basics(project_dir, profiles_dir, postgres_profile, unique_schema): + project = ProjectDefinition( + models={'my_model.sql': 'select 1 as id'} + ) + server_ctx = rpc_server( + project_dir=project_dir, schema=unique_schema, profiles_dir=profiles_dir + ) + schema_ctx = built_schema( + project_dir=project_dir, schema=unique_schema, profiles_dir=profiles_dir, test_kwargs={}, project_def=project, + ) + with schema_ctx, server_ctx as server: + querier = Querier(server) + + token = querier.is_async_result(querier.run_sql('select 1 as id')) + querier.is_result(querier.async_wait(token)) + + token = querier.is_async_result(querier.run()) + querier.is_result(querier.async_wait(token)) + + token = querier.is_async_result(querier.run_sql('select * from {{ ref("my_model") }}')) + querier.is_result(querier.async_wait(token)) + + token = querier.is_async_result(querier.run_sql('select * from {{ reff("my_model") }}')) + querier.is_error(querier.async_wait(token)) + + +def deps_with_packages(packages, bad_packages, project_dir, profiles_dir, schema): + project = ProjectDefinition( + models={ + 'my_model.sql': 'select 1 as id', + }, + packages={'packages': packages}, + ) + server_ctx = rpc_server( + project_dir=project_dir, schema=schema, profiles_dir=profiles_dir + ) + schema_ctx = built_schema( + project_dir=project_dir, schema=schema, profiles_dir=profiles_dir, test_kwargs={}, project_def=project, + ) + with schema_ctx, server_ctx as server: + querier = Querier(server) + + # we should be able to run sql queries at startup + token = querier.is_async_result(querier.run_sql('select 1 as id')) + querier.is_result(querier.async_wait(token)) + + # the status should be something positive + querier.is_result(querier.status()) + + # deps should pass + token = querier.is_async_result(querier.deps()) + querier.is_result(querier.async_wait(token)) + + # queries should work after deps + tok1 = querier.is_async_result(querier.run()) + tok2 = querier.is_async_result(querier.run_sql('select 1 as id')) + + querier.is_result(querier.async_wait(tok2)) + querier.is_result(querier.async_wait(tok1)) + + # now break the project + project.packages['packages'] = bad_packages + project.write_packages(project_dir, remove=True) + + # queries should still work because we haven't reloaded + tok1 = querier.is_async_result(querier.run()) + tok2 = querier.is_async_result(querier.run_sql('select 1 as id')) + + querier.is_result(querier.async_wait(tok2)) + querier.is_result(querier.async_wait(tok1)) + + # now run deps again, it should be sad + token = querier.is_async_result(querier.deps()) + querier.is_error(querier.async_wait(token)) + # it should also not be running. 
+ result = querier.is_result(querier.ps(active=True, completed=False)) + assert result['rows'] == [] + + # fix packages again + project.packages['packages'] = packages + project.write_packages(project_dir, remove=True) + # keep queries broken, we haven't run deps yet + querier.is_error(querier.run()) + + # deps should pass now + token = querier.is_async_result(querier.deps()) + querier.is_result(querier.async_wait(token)) + querier.is_result(querier.status()) + + tok1 = querier.is_async_result(querier.run()) + tok2 = querier.is_async_result(querier.run_sql('select 1 as id')) + + querier.is_result(querier.async_wait(tok2)) + querier.is_result(querier.async_wait(tok1)) + + +def test_rpc_deps_packages(project_dir, profiles_dir, postgres_profile, unique_schema): + packages = [{ + 'package': 'fishtown-analytics/dbt_utils', + 'version': '0.2.1', + }] + bad_packages = [{ + 'package': 'fishtown-analytics/dbt_util', + 'version': '0.2.1', + }] + deps_with_packages(packages, bad_packages, project_dir, profiles_dir, unique_schema) + + +def test_rpc_deps_git(project_dir, profiles_dir, postgres_profile, unique_schema): + packages = [{ + 'git': 'https://github.com/fishtown-analytics/dbt-utils.git', + 'revision': '0.2.1' + }] + # if you use a bad URL, git thinks it's a private repo and prompts for auth + bad_packages = [{ + 'git': 'https://github.com/fishtown-analytics/dbt-utils.git', + 'revision': 'not-a-real-revision' + }] + deps_with_packages(packages, bad_packages, project_dir, profiles_dir, unique_schema) + + +bad_schema_yml = ''' +version: 2 +sources: + - name: test_source + loader: custom + schema: "{{ var('test_run_schema') }}" + tables: + - name: test_table + identifier: source + tests: + - relationships: + # this is invalid + - column_name: favorite_color + - to: ref('descendant_model') + - field: favorite_color +''' + +fixed_schema_yml = ''' +version: 2 +sources: + - name: test_source + loader: custom + schema: "{{ var('test_run_schema') }}" + tables: + - name: test_table + identifier: source +''' + + +def test_rpc_status_error(project_dir, profiles_dir, postgres_profile, unique_schema): + project = ProjectDefinition( + models={ + 'descendant_model.sql': 'select * from {{ source("test_source", "test_table") }}', + 'schema.yml': bad_schema_yml, + } + ) + server_ctx = rpc_server( + project_dir=project_dir, schema=unique_schema, profiles_dir=profiles_dir, criteria='error', + ) + schema_ctx = built_schema( + project_dir=project_dir, schema=unique_schema, profiles_dir=profiles_dir, test_kwargs={}, project_def=project, + ) + with schema_ctx, server_ctx as server: + querier = Querier(server) + + # the status should be an error result + result = querier.is_result(querier.status()) + assert 'error' in result + assert 'message' in result['error'] + assert 'Invalid test config' in result['error']['message'] + assert 'status' in result + assert result['status'] == 'error' + assert 'logs' in result + logs = result['logs'] + assert len(logs) > 0 + for key in ('message', 'timestamp', 'levelname', 'level'): + assert key in logs[0] + assert 'pid' in result + assert server.pid == result['pid'] + + error = querier.is_error(querier.compile_sql('select 1 as id')) + assert 'code' in error + assert error['code'] == 10011 + assert 'message' in error + assert error['message'] == 'RPC server failed to compile project, call the "status" method for compile status' + assert 'data' in error + assert 'message' in error['data'] + assert 'Invalid test config' in error['data']['message'] + + # deps should fail because it still 
can't parse the manifest + token = querier.is_async_result(querier.deps()) + querier.is_error(querier.async_wait(token)) + + # and not resolve the issue + result = querier.is_result(querier.status()) + assert 'error' in result + assert 'message' in result['error'] + assert 'Invalid test config' in result['error']['message'] + + error = querier.is_error(querier.compile_sql('select 1 as id')) + assert 'code' in error + assert error['code'] == 10011 + + project.models['schema.yml'] = fixed_schema_yml + project.write_models(project_dir, remove=True) + + # deps should work + token = querier.is_async_result(querier.deps()) + querier.is_result(querier.async_wait(token)) + + result = querier.is_result(querier.status()) + assert result.get('error') is None + assert 'status' in result + assert result['status'] == 'ready' + + querier.is_result(querier.compile_sql('select 1 as id')) + + +def test_gc_change_interval(project_dir, profiles_dir, postgres_profile, unique_schema): + project = ProjectDefinition( + models={'my_model.sql': 'select 1 as id'} + ) + server_ctx = rpc_server( + project_dir=project_dir, schema=unique_schema, profiles_dir=profiles_dir + ) + schema_ctx = built_schema( + project_dir=project_dir, schema=unique_schema, profiles_dir=profiles_dir, test_kwargs={}, project_def=project, + ) + with schema_ctx, server_ctx as server: + querier = Querier(server) + + for _ in range(10): + token = querier.is_async_result(querier.run()) + querier.is_result(querier.async_wait(token)) + + result = querier.is_result(querier.ps(True, True)) + assert len(result['rows']) == 10 + + result = querier.is_result(querier.gc(settings=dict(maxsize=1000, reapsize=5, auto_reap_age=0.1))) + + for k in ('deleted', 'missing', 'running'): + assert k in result + assert len(result[k]) == 0 + + time.sleep(0.5) + + result = querier.is_result(querier.ps(True, True)) + assert len(result['rows']) == 0 + + result = querier.is_result(querier.gc(settings=dict(maxsize=2, reapsize=5, auto_reap_age=100000))) + for k in ('deleted', 'missing', 'running'): + assert k in result + assert len(result[k]) == 0 + + time.sleep(0.5) + + for _ in range(10): + token = querier.is_async_result(querier.run()) + querier.is_result(querier.async_wait(token)) + + time.sleep(0.5) + result = querier.is_result(querier.ps(True, True)) + assert len(result['rows']) == 2 diff --git a/test/unit/test_deps.py b/test/unit/test_deps.py index 14c7f392ef4..43161d0af26 100644 --- a/test/unit/test_deps.py +++ b/test/unit/test_deps.py @@ -1,12 +1,18 @@ import unittest from unittest import mock +import dbt.deps import dbt.exceptions -from dbt.task.deps import ( - GitUnpinnedPackage, LocalUnpinnedPackage, RegistryUnpinnedPackage, - LocalPackageContract, GitPackageContract, RegistryPackageContract, - resolve_packages +from dbt.deps.git import GitUnpinnedPackage +from dbt.deps.local import LocalUnpinnedPackage +from dbt.deps.registry import RegistryUnpinnedPackage +from dbt.deps.resolver import resolve_packages +from dbt.contracts.project import ( + LocalPackage, + GitPackage, + RegistryPackage, ) + from dbt.contracts.project import PackageConfig from dbt.semver import VersionSpecifier @@ -15,7 +21,7 @@ class TestLocalPackage(unittest.TestCase): def test_init(self): - a_contract = LocalPackageContract.from_dict({'local': '/path/to/package'}) + a_contract = LocalPackage.from_dict({'local': '/path/to/package'}) self.assertEqual(a_contract.local, '/path/to/package') a = LocalUnpinnedPackage.from_contract(a_contract) self.assertEqual(a.local, '/path/to/package') @@ -26,7 
+32,7 @@ def test_init(self): class TestGitPackage(unittest.TestCase): def test_init(self): - a_contract = GitPackageContract.from_dict( + a_contract = GitPackage.from_dict( {'git': 'http://example.com', 'revision': '0.0.1'} ) self.assertEqual(a_contract.git, 'http://example.com') @@ -46,15 +52,15 @@ def test_init(self): def test_invalid(self): with self.assertRaises(ValidationError): - GitPackageContract.from_dict( + GitPackage.from_dict( {'git': 'http://example.com', 'version': '0.0.1'} ) def test_resolve_ok(self): - a_contract = GitPackageContract.from_dict( + a_contract = GitPackage.from_dict( {'git': 'http://example.com', 'revision': '0.0.1'} ) - b_contract = GitPackageContract.from_dict( + b_contract = GitPackage.from_dict( {'git': 'http://example.com', 'revision': '0.0.1', 'warn-unpinned': False} ) @@ -71,10 +77,10 @@ def test_resolve_ok(self): self.assertFalse(c_pinned.warn_unpinned) def test_resolve_fail(self): - a_contract = GitPackageContract.from_dict( + a_contract = GitPackage.from_dict( {'git': 'http://example.com', 'revision': '0.0.1'} ) - b_contract = GitPackageContract.from_dict( + b_contract = GitPackage.from_dict( {'git': 'http://example.com', 'revision': '0.0.2'} ) a = GitUnpinnedPackage.from_contract(a_contract) @@ -87,7 +93,7 @@ def test_resolve_fail(self): c.resolved() def test_default_revision(self): - a_contract = GitPackageContract.from_dict({'git': 'http://example.com'}) + a_contract = GitPackage.from_dict({'git': 'http://example.com'}) self.assertEqual(a_contract.revision, None) self.assertIs(a_contract.warn_unpinned, None) @@ -105,7 +111,7 @@ def test_default_revision(self): class TestHubPackage(unittest.TestCase): def setUp(self): - self.patcher = mock.patch('dbt.task.deps.registry') + self.patcher = mock.patch('dbt.deps.registry.registry') self.registry = self.patcher.start() self.index_cached = self.registry.index_cached self.get_available_versions = self.registry.get_available_versions @@ -136,7 +142,7 @@ def tearDown(self): self.patcher.stop() def test_init(self): - a_contract = RegistryPackageContract( + a_contract = RegistryPackage( package='fishtown-analytics-test/a', version='0.1.2', ) @@ -164,16 +170,16 @@ def test_init(self): def test_invalid(self): with self.assertRaises(ValidationError): - RegistryPackageContract.from_dict( + RegistryPackage.from_dict( {'package': 'namespace/name', 'key': 'invalid'} ) def test_resolve_ok(self): - a_contract = RegistryPackageContract( + a_contract = RegistryPackage( package='fishtown-analytics-test/a', version='0.1.2' ) - b_contract = RegistryPackageContract( + b_contract = RegistryPackage( package='fishtown-analytics-test/a', version='0.1.2' ) @@ -210,7 +216,7 @@ def test_resolve_ok(self): self.assertEqual(c_pinned.source_type(), 'hub') def test_resolve_missing_package(self): - a = RegistryUnpinnedPackage.from_contract(RegistryPackageContract( + a = RegistryUnpinnedPackage.from_contract(RegistryPackage( package='fishtown-analytics-test/b', version='0.1.2' )) @@ -221,7 +227,7 @@ def test_resolve_missing_package(self): self.assertEqual(msg, str(exc.exception)) def test_resolve_missing_version(self): - a = RegistryUnpinnedPackage.from_contract(RegistryPackageContract( + a = RegistryUnpinnedPackage.from_contract(RegistryPackage( package='fishtown-analytics-test/a', version='0.1.4' )) @@ -236,11 +242,11 @@ def test_resolve_missing_version(self): self.assertEqual(msg, str(exc.exception)) def test_resolve_conflict(self): - a_contract = RegistryPackageContract( + a_contract = RegistryPackage( 
package='fishtown-analytics-test/a', version='0.1.2' ) - b_contract = RegistryPackageContract( + b_contract = RegistryPackage( package='fishtown-analytics-test/a', version='0.1.3' ) @@ -257,11 +263,11 @@ def test_resolve_conflict(self): self.assertEqual(msg, str(exc.exception)) def test_resolve_ranges(self): - a_contract = RegistryPackageContract( + a_contract = RegistryPackage( package='fishtown-analytics-test/a', version='0.1.2' ) - b_contract = RegistryPackageContract( + b_contract = RegistryPackage( package='fishtown-analytics-test/a', version='<0.1.4' ) @@ -321,7 +327,7 @@ def package_version(self, name, version): class TestPackageSpec(unittest.TestCase): def setUp(self): - self.patcher = mock.patch('dbt.task.deps.registry') + self.patcher = mock.patch('dbt.deps.registry.registry') self.registry = self.patcher.start() self.mock_registry = MockRegistry(packages={ 'fishtown-analytics-test/a': { diff --git a/tox.ini b/tox.ini index 83eb19cb69b..50e570b051b 100644 --- a/tox.ini +++ b/tox.ini @@ -30,11 +30,13 @@ passenv = * setenv = HOME=/home/tox commands = /bin/bash -c '{envpython} -m pytest --durations 0 -v -m profile_postgres {posargs} -n4 test/integration/*' + /bin/bash -c '{envpython} -m pytest --durations 0 -v {posargs} -n4 test/rpc/*' deps = -e {toxinidir}/core -e {toxinidir}/plugins/postgres -r{toxinidir}/dev_requirements.txt + [testenv:integration-snowflake-py36] basepython = python3.6 passenv = * @@ -102,7 +104,8 @@ basepython = python3.7 passenv = * setenv = HOME=/home/tox -commands = /bin/bash -c '{envpython} -m pytest --durations 0 -v -m profile_postgres {posargs} -n4 test/integration/*' +commands = /bin/bash -c '{envpython} -m pytest --durations 0 -v -m profile_postgres {posargs} -n4 test/integration/*' && \ + /bin/bash -c '{envpython} -m pytest --durations 0 -v {posargs} -n4 test/rpc/*' deps = -e {toxinidir}/core -e {toxinidir}/plugins/postgres