diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py
index a48a97b41ea..52368161aab 100644
--- a/core/dbt/compilation.py
+++ b/core/dbt/compilation.py
@@ -184,7 +184,7 @@ def link_graph(self, linker, manifest):
         if cycle:
             raise RuntimeError("Found a cycle: {}".format(cycle))

-    def compile(self, manifest):
+    def compile(self, manifest, write=True):
        linker = Linker()

        self.link_graph(linker, manifest)
@@ -196,25 +196,35 @@ def compile(self, manifest):
                manifest.macros.items()):
            stats[node.resource_type] += 1

-        self.write_graph_file(linker, manifest)
+        if write:
+            self.write_graph_file(linker, manifest)
        print_compile_stats(stats)

        return linker


-def compile_manifest(config, manifest):
+def compile_manifest(config, manifest, write=True):
    compiler = Compiler(config)
    compiler.initialize()
-    return compiler.compile(manifest)
+    return compiler.compile(manifest, write=write)


-def compile_node(adapter, config, node, manifest, extra_context):
+def _is_writable(node):
+    if not node.injected_sql:
+        return False
+
+    if dbt.utils.is_type(node, NodeType.Archive):
+        return False
+
+    return True
+
+
+def compile_node(adapter, config, node, manifest, extra_context, write=True):
    compiler = Compiler(config)
    node = compiler.compile_node(node, manifest, extra_context)
    node = _inject_runtime_config(adapter, node, extra_context)

-    if(node.injected_sql is not None and
-       not (dbt.utils.is_type(node, NodeType.Archive))):
+    if write and _is_writable(node):
        logger.debug('Writing injected SQL for node "{}"'.format(
            node.unique_id))

diff --git a/core/dbt/config/__init__.py b/core/dbt/config/__init__.py
index b5280511ef7..3a9b433d38b 100644
--- a/core/dbt/config/__init__.py
+++ b/core/dbt/config/__init__.py
@@ -1,22 +1,5 @@
 from .renderer import ConfigRenderer
-from .profile import Profile, UserConfig
+from .profile import Profile, UserConfig, PROFILES_DIR
 from .project import Project
-from .profile import read_profile
-from .profile import PROFILES_DIR
 from .runtime import RuntimeConfig
-
-
-def read_profiles(profiles_dir=None):
-    """This is only used in main, for some error handling"""
-    if profiles_dir is None:
-        profiles_dir = PROFILES_DIR
-
-    raw_profiles = read_profile(profiles_dir)
-
-    if raw_profiles is None:
-        profiles = {}
-    else:
-        profiles = {k: v for (k, v) in raw_profiles.items() if k != 'config'}
-
-    return profiles
diff --git a/core/dbt/config/profile.py b/core/dbt/config/profile.py
index bb5c91cc246..835718ea25a 100644
--- a/core/dbt/config/profile.py
+++ b/core/dbt/config/profile.py
@@ -335,14 +335,12 @@ def from_raw_profiles(cls, raw_profiles, profile_name, cli_vars,
        )

    @classmethod
-    def from_args(cls, args, project_profile_name=None, cli_vars=None):
+    def from_args(cls, args, project_profile_name=None):
        """Given the raw profiles as read from disk and the name of the
        desired profile if specified, return the profile component of the
        runtime config.

        :param args argparse.Namespace: The arguments as parsed from the cli.
-        :param cli_vars dict: The command-line variables passed as arguments,
-            as a dict.
        :param project_profile_name Optional[str]: The profile name, if
            specified in a project.
        :raises DbtProjectError: If there is no profile name specified in the
@@ -352,9 +350,7 @@ def from_args(cls, args, project_profile_name=None):
            target could not be found.
        :returns Profile: The new Profile object.
""" - if cli_vars is None: - cli_vars = parse_cli_vars(getattr(args, 'vars', '{}')) - + cli_vars = parse_cli_vars(getattr(args, 'vars', '{}')) threads_override = getattr(args, 'threads', None) target_override = getattr(args, 'target', None) raw_profiles = read_profile(args.profiles_dir) diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 13d01599143..5fbf06a86a6 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -377,6 +377,10 @@ def from_project_root(cls, project_root, cli_vars): def from_current_directory(cls, cli_vars): return cls.from_project_root(os.getcwd(), cli_vars) + @classmethod + def from_args(cls, args): + return cls.from_current_directory(getattr(args, 'vars', '{}')) + def hashed_name(self): return hashlib.md5(self.project_name.encode('utf-8')).hexdigest() diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index ee654474a5b..a125be7fd40 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -171,16 +171,13 @@ def from_args(cls, args): :raises DbtProfileError: If the profile is invalid or missing. :raises ValidationException: If the cli variables are invalid. """ - cli_vars = parse_cli_vars(getattr(args, 'vars', '{}')) - # build the project and read in packages.yml - project = Project.from_current_directory(cli_vars) + project = Project.from_args(args) # build the profile profile = Profile.from_args( args=args, - project_profile_name=project.profile_name, - cli_vars=cli_vars + project_profile_name=project.profile_name ) return cls.from_parts( diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 99866cd8ddb..83548a5a0d7 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -409,3 +409,13 @@ def get_used_schemas(self): def get_used_databases(self): return frozenset(node.database for node in self.nodes.values()) + + def deepcopy(self, config=None): + return Manifest( + nodes={k: v.incorporate() for k, v in self.nodes.items()}, + macros={k: v.incorporate() for k, v in self.macros.items()}, + docs={k: v.incorporate() for k, v in self.docs.items()}, + generated_at=self.generated_at, + disabled=[n.incorporate() for n in self.disabled], + config=config + ) diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 30de42ef695..b3b4ea76e2d 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -75,6 +75,7 @@ NodeType.Seed, # we need this if parse_node is going to handle archives. 
            NodeType.Archive,
+            NodeType.RPCCall,
        ]
    },
 },
diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py
index fc42b0930a5..5a6be39dce2 100644
--- a/core/dbt/contracts/results.py
+++ b/core/dbt/contracts/results.py
@@ -458,3 +458,79 @@ class FreshnessRunOutput(APIObject):

    def __init__(self, meta, sources):
        super(FreshnessRunOutput, self).__init__(meta=meta, sources=sources)
+
+
+REMOTE_COMPILE_RESULT_CONTRACT = {
+    'type': 'object',
+    'additionalProperties': False,
+    'properties': {
+        'raw_sql': {
+            'type': 'string',
+        },
+        'compiled_sql': {
+            'type': 'string',
+        },
+        'timing': {
+            'type': 'array',
+            'items': TIMING_INFO_CONTRACT,
+        },
+    },
+    'required': ['raw_sql', 'compiled_sql', 'timing']
+}
+
+
+class RemoteCompileResult(APIObject):
+    SCHEMA = REMOTE_COMPILE_RESULT_CONTRACT
+
+    def __init__(self, raw_sql, compiled_sql, timing=None, **kwargs):
+        if timing is None:
+            timing = []
+        super(RemoteCompileResult, self).__init__(
+            raw_sql=raw_sql,
+            compiled_sql=compiled_sql,
+            timing=timing,
+            **kwargs
+        )
+
+    @property
+    def node(self):
+        return None
+
+    @property
+    def error(self):
+        return None
+
+
+REMOTE_RUN_RESULT_CONTRACT = deep_merge(REMOTE_COMPILE_RESULT_CONTRACT, {
+    'properties': {
+        'table': {
+            'type': 'object',
+            'properties': {
+                'column_names': {
+                    'type': 'array',
+                    'items': {'type': 'string'},
+                },
+                'rows': {
+                    'type': 'array',
+                    # any item type is ok
+                },
+            },
+            'required': ['rows', 'column_names'],
+        },
+    },
+    'required': ['table'],
+})
+
+
+class RemoteRunResult(RemoteCompileResult):
+    SCHEMA = REMOTE_RUN_RESULT_CONTRACT
+
+    def __init__(self, raw_sql, compiled_sql, timing=None, table=None):
+        if table is None:
+            table = []
+        super(RemoteRunResult, self).__init__(
+            raw_sql=raw_sql,
+            compiled_sql=compiled_sql,
+            timing=timing,
+            table=table
+        )
diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py
index 10e80c6422a..7885d0b771e 100644
--- a/core/dbt/exceptions.py
+++ b/core/dbt/exceptions.py
@@ -21,6 +21,10 @@ class InternalException(Exception):
    pass


+class RPCException(Exception):
+    pass
+
+
 class RuntimeException(RuntimeError, Exception):
    def __init__(self, msg, node=None):
        self.stack = []
diff --git a/core/dbt/logger.py b/core/dbt/logger.py
index 6b2ab24b391..725f42fc771 100644
--- a/core/dbt/logger.py
+++ b/core/dbt/logger.py
@@ -43,9 +43,14 @@
 logging.getLogger('google').setLevel(logging.INFO)
 logging.getLogger('snowflake.connector').setLevel(logging.INFO)
 logging.getLogger('parsedatetime').setLevel(logging.INFO)
+# we never want to see werkzeug logs
+logging.getLogger('werkzeug').setLevel(logging.CRITICAL)

 # provide this for the cache.
 CACHE_LOGGER = logging.getLogger('dbt.cache')
+# provide this for RPC connection logging
+RPC_LOGGER = logging.getLogger('dbt.rpc')
+

 # Redirect warnings through our logging setup
 # They will be logged to a file below
diff --git a/core/dbt/main.py b/core/dbt/main.py
index 7f9f36255e9..231c8e80802 100644
--- a/core/dbt/main.py
+++ b/core/dbt/main.py
@@ -5,6 +5,7 @@
 import os.path
 import sys
 import traceback
+from contextlib import contextmanager

 import dbt.version
 import dbt.flags as flags
@@ -21,6 +22,7 @@
 import dbt.task.serve as serve_task
 import dbt.task.freshness as freshness_task
 import dbt.task.run_operation as run_operation_task
+from dbt.task.rpc_server import RPCServerTask
 from dbt.adapters.factory import reset_adapters

 import dbt.tracking
@@ -30,8 +32,7 @@
 import dbt.profiler

 from dbt.utils import ExitCodes
-from dbt.config import Project, UserConfig, RuntimeConfig, PROFILES_DIR, \
-    read_profiles
+from dbt.config import Project, UserConfig, RuntimeConfig, PROFILES_DIR
 from dbt.exceptions import DbtProjectError, DbtProfileError, RuntimeException


@@ -149,138 +150,60 @@ def handle_and_check(args):

        reset_adapters()

-        try:
-            task, res = run_from_args(parsed)
-        finally:
-            dbt.tracking.flush()
-
+        task, res = run_from_args(parsed)
        success = task.interpret_results(res)
        return res, success


-def get_nearest_project_dir():
-    root_path = os.path.abspath(os.sep)
-    cwd = os.getcwd()
-
-    while cwd != root_path:
-        project_file = os.path.join(cwd, "dbt_project.yml")
-        if os.path.exists(project_file):
-            return cwd
-        cwd = os.path.dirname(cwd)
-
-    return None
-
-
-def run_from_args(parsed):
-    task = None
-    cfg = None
-
-    if parsed.which in ('init', 'debug'):
-        # bypass looking for a project file if we're running `dbt init` or
-        # `dbt debug`
-        task = parsed.cls(args=parsed)
-    else:
-        nearest_project_dir = get_nearest_project_dir()
-        if nearest_project_dir is None:
-            raise RuntimeException(
-                "fatal: Not a dbt project (or any of the parent directories). "
-                "Missing dbt_project.yml file"
-            )
-
-        os.chdir(nearest_project_dir)
-
-        res = invoke_dbt(parsed)
-        if res is None:
-            raise RuntimeException("Could not run dbt")
-        else:
-            task, cfg = res
-
-    log_path = None
-
-    if cfg is not None:
-        log_path = cfg.log_path
-
-    initialize_logger(parsed.debug, log_path)
-    logger.debug("Tracking: {}".format(dbt.tracking.active_user.state()))
-
-    dbt.tracking.track_invocation_start(config=cfg, args=parsed)
-
-    results = run_from_task(task, cfg, parsed)
-
-    return task, results
-
-
-def run_from_task(task, cfg, parsed_args):
-    result = None
+@contextmanager
+def track_run(task):
+    dbt.tracking.track_invocation_start(config=task.config, args=task.args)
    try:
-        result = task.run()
+        yield
        dbt.tracking.track_invocation_end(
-            config=cfg, args=parsed_args, result_type="ok"
+            config=task.config, args=task.args, result_type="ok"
        )
    except (dbt.exceptions.NotImplementedException,
            dbt.exceptions.FailedToConnectException) as e:
        logger.info('ERROR: {}'.format(e))
        dbt.tracking.track_invocation_end(
-            config=cfg, args=parsed_args, result_type="error"
+            config=task.config, args=task.args, result_type="error"
        )
    except Exception as e:
        dbt.tracking.track_invocation_end(
-            config=cfg, args=parsed_args, result_type="error"
+            config=task.config, args=task.args, result_type="error"
        )
        raise
+    finally:
+        dbt.tracking.flush()

-    return result
-
-
-def invoke_dbt(parsed):
-    task = None
-    cfg = None

+def run_from_args(parsed):
    log_cache_events(getattr(parsed, 'log_cache_events', False))
+    update_flags(parsed)
+
    logger.info("Running with dbt{}".format(dbt.version.installed))

-    try:
-        if parsed.which in {'deps', 'clean'}:
-            # deps doesn't need a profile, so don't require one.
-            cfg = Project.from_current_directory(getattr(parsed, 'vars', '{}'))
-        elif parsed.which != 'debug':
-            # for debug, we will attempt to load the various configurations as
-            # part of the task, so just leave cfg=None.
-            cfg = RuntimeConfig.from_args(parsed)
-    except DbtProjectError as e:
-        logger.info("Encountered an error while reading the project:")
-        logger.info(dbt.compat.to_string(e))
-
-        dbt.tracking.track_invalid_invocation(
-            config=cfg,
-            args=parsed,
-            result_type=e.result_type)
-
-        return None
-    except DbtProfileError as e:
-        logger.info("Encountered an error while reading profiles:")
-        logger.info(" ERROR {}".format(str(e)))
-
-        all_profiles = read_profiles(parsed.profiles_dir).keys()
-
-        if len(all_profiles) > 0:
-            logger.info("Defined profiles:")
-            for profile in all_profiles:
-                logger.info(" - {}".format(profile))
-        else:
-            logger.info("There are no profiles defined in your "
-                        "profiles.yml file")
+    # this will convert DbtConfigErrors into RuntimeExceptions
+    task = parsed.cls.from_args(args=parsed)
+    logger.debug("running dbt with arguments %s", parsed)
+
+    log_path = None
+    if task.config is not None:
+        log_path = getattr(task.config, 'log_path', None)
+    initialize_logger(parsed.debug, log_path)
+    logger.debug("Tracking: {}".format(dbt.tracking.active_user.state()))
+
+    results = None

-        logger.info(PROFILES_HELP_MESSAGE)
+    with track_run(task):
+        results = task.run()

-        dbt.tracking.track_invalid_invocation(
-            config=cfg,
-            args=parsed,
-            result_type=e.result_type)
+    return task, results

-        return None

+def update_flags(parsed):
    flags.NON_DESTRUCTIVE = getattr(parsed, 'non_destructive', False)
    flags.USE_CACHE = getattr(parsed, 'use_cache', True)
@@ -298,12 +221,6 @@ def invoke_dbt(parsed):
    elif arg_full_refresh:
        flags.FULL_REFRESH = True

-    logger.debug("running dbt with arguments %s", parsed)
-
-    task = parsed.cls(args=parsed, config=cfg)
-
-    return task, cfg
-

 def _build_base_subparser():
    base_subparser = argparse.ArgumentParser(add_help=False)
@@ -479,7 +396,7 @@ def _build_docs_generate_subparser(subparsers, base_subparser):
    return generate_sub


-def _add_common_arguments(*subparsers):
+def _add_selection_arguments(*subparsers):
    for sub in subparsers:
        sub.add_argument(
            '-m',
@@ -498,15 +415,10 @@ def _add_common_arguments(*subparsers):
            Specify the models to exclude.
            """
        )
-        sub.add_argument(
-            '--threads',
-            type=int,
-            required=False,
-            help="""
-            Specify number of threads to use while executing models. Overrides
-            settings in profiles.yml.
-            """
-        )
+
+
+def _add_table_mutability_arguments(*subparsers):
+    for sub in subparsers:
        sub.add_argument(
            '--non-destructive',
            action='store_true',
@@ -522,6 +434,19 @@ def _add_common_arguments(*subparsers):
            If specified, DBT will drop incremental models and
            fully-recalculate the incremental table from the model definition.
            """)
+
+
+def _add_common_arguments(*subparsers):
+    for sub in subparsers:
+        sub.add_argument(
+            '--threads',
+            type=int,
+            required=False,
+            help="""
+            Specify number of threads to use while executing models. Overrides
+            settings in profiles.yml.
+            """
+        )
        sub.add_argument(
            '--no-version-check',
            dest='version_check',
@@ -584,32 +509,6 @@ def _build_test_subparser(subparsers, base_subparser):
        action='store_true',
        help='Run constraint validations from schema.yml files'
    )
-    sub.add_argument(
-        '--threads',
-        type=int,
-        required=False,
-        help="""
-        Specify number of threads to use while executing tests. Overrides
-        settings in profiles.yml
-        """
-    )
-    sub.add_argument(
-        '-m',
-        '--models',
-        required=False,
-        nargs='+',
-        help="""
-        Specify the models to test.
-        """
-    )
-    sub.add_argument(
-        '--exclude',
-        required=False,
-        nargs='+',
-        help="""
-        Specify the models to exclude from testing.
- """ - ) sub.set_defaults(cls=test_task.TestTask, which='test') return sub @@ -645,6 +544,30 @@ def _build_source_snapshot_freshness_subparser(subparsers, base_subparser): return sub +def _build_rpc_subparser(subparsers, base_subparser): + sub = subparsers.add_parser( + 'rpc', + parents=[base_subparser], + help='Start a json-rpc server', + ) + sub.add_argument( + '--host', + default='0.0.0.0', + help='Specify the host to listen on for the rpc server.' + ) + sub.add_argument( + '--port', + default=8580, + type=int, + help='Specify the port number for the rpc server.' + ) + sub.set_defaults(cls=RPCServerTask, which='rpc') + # the rpc task does a 'compile', so we need these attributes to exist, but + # we don't want users to be allowed to set them. + sub.set_defaults(models=None, exclude=None) + return sub + + def parse_args(args): p = DBTArgumentParser( prog='dbt: data build tool', @@ -714,14 +637,21 @@ def parse_args(args): _build_deps_subparser(subs, base_subparser) _build_archive_subparser(subs, base_subparser) + rpc_sub = _build_rpc_subparser(subs, base_subparser) run_sub = _build_run_subparser(subs, base_subparser) compile_sub = _build_compile_subparser(subs, base_subparser) generate_sub = _build_docs_generate_subparser(docs_subs, base_subparser) - _add_common_arguments(run_sub, compile_sub, generate_sub) + test_sub = _build_test_subparser(subs, base_subparser) + # --threads, --no-version-check + _add_common_arguments(run_sub, compile_sub, generate_sub, test_sub, + rpc_sub) + # --models, --exclude + _add_selection_arguments(run_sub, compile_sub, generate_sub, test_sub) + # --full-refresh, --non-destructive + _add_table_mutability_arguments(run_sub, compile_sub) _build_seed_subparser(subs, base_subparser) _build_docs_serve_subparser(docs_subs, base_subparser) - _build_test_subparser(subs, base_subparser) _build_source_snapshot_freshness_subparser(source_subs, base_subparser) sub = subs.add_parser( diff --git a/core/dbt/node_runners.py b/core/dbt/node_runners.py index b949f0221f5..9e0cb61a898 100644 --- a/core/dbt/node_runners.py +++ b/core/dbt/node_runners.py @@ -5,8 +5,9 @@ from dbt.node_types import NodeType, RunHookType from dbt.adapters.factory import get_adapter from dbt.contracts.results import RunModelResult, collect_timing_info, \ - SourceFreshnessResult, PartialResult + SourceFreshnessResult, PartialResult, RemoteCompileResult, RemoteRunResult from dbt.compilation import compile_node +from dbt.utils import timestring import dbt.clients.jinja import dbt.context.runtime @@ -493,3 +494,66 @@ def print_result_line(self, result): schema_name, self.node_index, self.num_nodes) + + +class RPCCompileRunner(CompileRunner): + def __init__(self, config, adapter, node, node_index, num_nodes): + super(RPCCompileRunner, self).__init__(config, adapter, node, + node_index, num_nodes) + + def before_execute(self): + pass + + def after_execute(self, result): + pass + + def compile(self, manifest): + return compile_node(self.adapter, self.config, self.node, manifest, {}, + write=False) + + def execute(self, compiled_node, manifest): + return RemoteCompileResult( + raw_sql=compiled_node.raw_sql, + compiled_sql=compiled_node.injected_sql + ) + + def error_result(self, node, error, start_time, timing_info): + raise dbt.exceptions.RPCException(error) + + def ephemeral_result(self, node, start_time, timing_info): + raise dbt.exceptions.NotImplementedException( + 'cannot execute ephemeral nodes remotely!' 
+        )
+
+    def from_run_result(self, result, start_time, timing_info):
+        timing = [t.serialize() for t in timing_info]
+        return RemoteCompileResult(
+            raw_sql=result.raw_sql,
+            compiled_sql=result.compiled_sql,
+            timing=timing
+        )
+
+
+class RPCExecuteRunner(RPCCompileRunner):
+    def from_run_result(self, result, start_time, timing_info):
+        timing = [t.serialize() for t in timing_info]
+        return RemoteRunResult(
+            raw_sql=result.raw_sql,
+            compiled_sql=result.compiled_sql,
+            table=result.table,
+            timing=timing
+        )
+
+    def execute(self, compiled_node, manifest):
+        status, table = self.adapter.execute(compiled_node.injected_sql,
+                                             fetch=True)
+        table = {
+            'column_names': list(table.column_names),
+            'rows': [list(row) for row in table]
+        }
+
+        return RemoteRunResult(
+            raw_sql=compiled_node.raw_sql,
+            compiled_sql=compiled_node.injected_sql,
+            table=table
+        )
diff --git a/core/dbt/node_types.py b/core/dbt/node_types.py
index 4f097ab1070..a633b67f2dc 100644
--- a/core/dbt/node_types.py
+++ b/core/dbt/node_types.py
@@ -10,6 +10,7 @@ class NodeType(object):
    Seed = 'seed'
    Documentation = 'documentation'
    Source = 'source'
+    RPCCall = 'rpc'

    @classmethod
    def executable(cls):
@@ -21,6 +22,7 @@ def executable(cls):
            cls.Operation,
            cls.Seed,
            cls.Documentation,
+            cls.RPCCall,
        ]

    @classmethod
diff --git a/core/dbt/parser/analysis.py b/core/dbt/parser/analysis.py
index 5d218544983..c466ead1cfe 100644
--- a/core/dbt/parser/analysis.py
+++ b/core/dbt/parser/analysis.py
@@ -7,3 +7,8 @@ class AnalysisParser(BaseSqlParser):
    @classmethod
    def get_compiled_path(cls, name, relative_path):
        return os.path.join('analysis', relative_path)
+
+
+class RPCCallParser(AnalysisParser):
+    def get_compiled_path(cls, name, relative_path):
+        return os.path.join('rpc', relative_path)
diff --git a/core/dbt/parser/base_sql.py b/core/dbt/parser/base_sql.py
index d6d7322a423..d412e2e6f9e 100644
--- a/core/dbt/parser/base_sql.py
+++ b/core/dbt/parser/base_sql.py
@@ -62,8 +62,22 @@ def load_and_parse(self, package_name, root_dir, relative_dirs,

        return self.parse_sql_nodes(result, tags)

-    def parse_sql_nodes(self, nodes, tags=None):
+    def parse_sql_node(self, node_dict, tags=None):
+        if tags is None:
+            tags = []
+
+        node = UnparsedNode(**node_dict)
+        package_name = node.package_name
+
+        unique_id = self.get_path(node.resource_type,
+                                  package_name,
+                                  node.name)
+
+        project = self.all_projects.get(package_name)
+        node_parsed = self.parse_node(node, unique_id, project, tags=tags)
+        return unique_id, node_parsed
+
+    def parse_sql_nodes(self, nodes, tags=None):
        if tags is None:
            tags = []

@@ -71,18 +85,10 @@ def parse_sql_nodes(self, nodes, tags=None):
        disabled = []

        for n in nodes:
-            node = UnparsedNode(**n)
-            package_name = node.package_name
-
-            node_path = self.get_path(node.resource_type,
-                                      package_name,
-                                      node.name)
-
-            project = self.all_projects.get(package_name)
-            node_parsed = self.parse_node(node, node_path, project, tags=tags)
+            node_path, node_parsed = self.parse_sql_node(n, tags)

            # Ignore disabled nodes
-            if not node_parsed['config']['enabled']:
+            if not node_parsed.config['enabled']:
                disabled.append(node_parsed)
                continue
diff --git a/core/dbt/parser/util.py b/core/dbt/parser/util.py
index c4c2245746d..92c90fe68a0 100644
--- a/core/dbt/parser/util.py
+++ b/core/dbt/parser/util.py
@@ -124,96 +124,127 @@ def _get_node_column(cls, node, column_name):

        return column

+    @classmethod
+    def process_docs_for_node(cls, manifest, current_project, node):
+        target_doc = None
+        target_doc_name = None
+        target_doc_package = None
+        for docref in node.get('docrefs', []):
+            column_name = docref.get('column_name')
+            if column_name is None:
+                description = node.get('description', '')
+            else:
+                column = cls._get_node_column(node, column_name)
+                description = column.get('description', '')
+            target_doc_name = docref['documentation_name']
+            target_doc_package = docref['documentation_package']
+            context = {
+                'doc': docs(node, manifest, current_project, column_name),
+            }
+
+            # At this point, target_doc is a ParsedDocumentation, and we
+            # know that our documentation string has a 'docs("...")'
+            # pointing at it. We want to render it.
+            description = dbt.clients.jinja.get_rendered(description,
+                                                         context)
+            # now put it back.
+            if column_name is None:
+                node.set('description', description)
+            else:
+                column['description'] = description
+
    @classmethod
    def process_docs(cls, manifest, current_project):
-        for _, node in manifest.nodes.items():
-            target_doc = None
-            target_doc_name = None
-            target_doc_package = None
-            for docref in node.get('docrefs', []):
-                column_name = docref.get('column_name')
-                if column_name is None:
-                    description = node.get('description', '')
-                else:
-                    column = cls._get_node_column(node, column_name)
-                    description = column.get('description', '')
-                target_doc_name = docref['documentation_name']
-                target_doc_package = docref['documentation_package']
-                context = {
-                    'doc': docs(node, manifest, current_project, column_name),
-                }
-
-                # At this point, target_doc is a ParsedDocumentation, and we
-                # know that our documentation string has a 'docs("...")'
-                # pointing at it. We want to render it.
-                description = dbt.clients.jinja.get_rendered(description,
-                                                             context)
-                # now put it back.
-                if column_name is None:
-                    node.set('description', description)
-                else:
-                    column['description'] = description
+        for node in manifest.nodes.values():
+            cls.process_docs_for_node(manifest, current_project, node)
        return manifest

    @classmethod
-    def process_refs(cls, manifest, current_project):
-        for _, node in manifest.nodes.items():
-            target_model = None
-            target_model_name = None
-            target_model_package = None
-
-            for ref in node.refs:
-                if len(ref) == 1:
-                    target_model_name = ref[0]
-                elif len(ref) == 2:
-                    target_model_package, target_model_name = ref
-
-                target_model = cls.resolve_ref(
-                    manifest,
-                    target_model_name,
-                    target_model_package,
-                    current_project,
-                    node.get('package_name'))
-
-                if target_model is None or target_model is cls.DISABLED:
-                    # This may raise. Even if it doesn't, we don't want to add
-                    # this node to the graph b/c there is no destination node
-                    node.config['enabled'] = False
-                    dbt.utils.invalid_ref_fail_unless_test(
-                        node, target_model_name, target_model_package,
-                        disabled=(target_model is cls.DISABLED)
-                    )
-
-                    continue
-
-                target_model_id = target_model.get('unique_id')
-
-                node.depends_on['nodes'].append(target_model_id)
-                manifest.nodes[node['unique_id']] = node
+    def process_refs_for_node(cls, manifest, current_project, node):
+        """Given a manifest and a node in that manifest, process its refs"""
+        target_model = None
+        target_model_name = None
+        target_model_package = None
+
+        for ref in node.refs:
+            if len(ref) == 1:
+                target_model_name = ref[0]
+            elif len(ref) == 2:
+                target_model_package, target_model_name = ref
+
+            target_model = cls.resolve_ref(
+                manifest,
+                target_model_name,
+                target_model_package,
+                current_project,
+                node.get('package_name'))
+
+            if target_model is None or target_model is cls.DISABLED:
+                # This may raise. Even if it doesn't, we don't want to add
+                # this node to the graph b/c there is no destination node
+                node.config['enabled'] = False
+                dbt.utils.invalid_ref_fail_unless_test(
+                    node, target_model_name, target_model_package,
+                    disabled=(target_model is cls.DISABLED)
+                )
+
+                continue
+
+            target_model_id = target_model.get('unique_id')
+
+            node.depends_on['nodes'].append(target_model_id)
+            manifest.nodes[node['unique_id']] = node
+
+    @classmethod
+    def process_refs(cls, manifest, current_project):
+        for node in manifest.nodes.values():
+            cls.process_refs_for_node(manifest, current_project, node)
        return manifest

    @classmethod
-    def process_sources(cls, manifest, current_project):
-        for _, node in manifest.nodes.items():
-            target_source = None
-            for source_name, table_name in node.sources:
-                target_source = cls.resolve_source(
-                    manifest,
+    def process_sources_for_node(cls, manifest, current_project, node):
+        target_source = None
+        for source_name, table_name in node.sources:
+            target_source = cls.resolve_source(
+                manifest,
+                source_name,
+                table_name,
+                current_project,
+                node.get('package_name'))
+
+            if target_source is None:
+                # this follows the same pattern as refs
+                node.config['enabled'] = False
+                dbt.utils.invalid_source_fail_unless_test(
+                    node,
                    source_name,
-                    table_name,
-                    current_project,
-                    node.get('package_name'))
-
-                if target_source is None:
-                    # this folows the same pattern as refs
-                    node.config['enabled'] = False
-                    dbt.utils.invalid_source_fail_unless_test(
-                        node,
-                        source_name,
-                        table_name)
-                    continue
-                target_source_id = target_source.unique_id
-                node.depends_on['nodes'].append(target_source_id)
-                manifest.nodes[node['unique_id']] = node
+                    table_name)
+                continue
+            target_source_id = target_source.unique_id
+            node.depends_on['nodes'].append(target_source_id)
+            manifest.nodes[node['unique_id']] = node
+
+    @classmethod
+    def process_sources(cls, manifest, current_project):
+        for node in manifest.nodes.values():
+            cls.process_sources_for_node(manifest, current_project, node)
+        return manifest
+
+    @classmethod
+    def add_new_refs(cls, manifest, current_project, node):
+        """Given a new node that is not in the manifest, copy the manifest and
+        insert the new node into it as if it were part of regular ref
+        processing
+        """
+        manifest = manifest.deepcopy(config=current_project)
+        if node.unique_id in manifest.nodes:
+            # this should be _impossible_ due to the fact that rpc calls get
+            # a unique ID that starts with 'rpc'!
+            raise dbt.exceptions.raise_duplicate_resource_name(
+                manifest.nodes[node.unique_id], node
+            )
+        manifest.nodes[node.unique_id] = node
+        cls.process_sources_for_node(manifest, current_project, node)
+        cls.process_refs_for_node(manifest, current_project, node)
+        cls.process_docs_for_node(manifest, current_project, node)
        return manifest
diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py
new file mode 100644
index 00000000000..4fe9611beed
--- /dev/null
+++ b/core/dbt/task/base.py
@@ -0,0 +1,129 @@
+from abc import ABCMeta, abstractmethod
+import os
+
+import six
+
+from dbt.compat import to_string
+from dbt.config import RuntimeConfig, Project
+from dbt.config.profile import read_profile, PROFILES_DIR
+from dbt import flags
+from dbt import tracking
+from dbt.logger import GLOBAL_LOGGER as logger
+import dbt.exceptions
+
+
+class NoneConfig(object):
+    @classmethod
+    def from_args(cls, args):
+        return None
+
+
+def read_profiles(profiles_dir=None):
+    """This is only used for some error handling"""
+    if profiles_dir is None:
+        profiles_dir = PROFILES_DIR
+
+    raw_profiles = read_profile(profiles_dir)
+
+    if raw_profiles is None:
+        profiles = {}
+    else:
+        profiles = {k: v for (k, v) in raw_profiles.items() if k != 'config'}
+
+    return profiles
+
+
+PROFILES_HELP_MESSAGE = """
+For more information on configuring profiles, please consult the dbt docs:
+
+https://docs.getdbt.com/docs/configure-your-profile
+"""
+
+
+@six.add_metaclass(ABCMeta)
+class BaseTask(object):
+    ConfigType = NoneConfig
+
+    def __init__(self, args, config):
+        self.args = args
+        self.config = config
+
+    @classmethod
+    def from_args(cls, args):
+        try:
+            config = cls.ConfigType.from_args(args)
+        except dbt.exceptions.DbtProjectError as exc:
+            logger.info("Encountered an error while reading the project:")
+            logger.info(to_string(exc))
+
+            tracking.track_invalid_invocation(
+                args=args,
+                result_type=exc.result_type)
+            raise dbt.exceptions.RuntimeException('Could not run dbt')
+        except dbt.exceptions.DbtProfileError as exc:
+            logger.info("Encountered an error while reading profiles:")
+            logger.info(" ERROR {}".format(str(exc)))
+
+            all_profiles = read_profiles(args.profiles_dir).keys()
+
+            if len(all_profiles) > 0:
+                logger.info("Defined profiles:")
+                for profile in all_profiles:
+                    logger.info(" - {}".format(profile))
+            else:
+                logger.info("There are no profiles defined in your "
+                            "profiles.yml file")
+
+            logger.info(PROFILES_HELP_MESSAGE)
+
+            tracking.track_invalid_invocation(
+                args=args,
+                result_type=exc.result_type)
+            raise dbt.exceptions.RuntimeException('Could not run dbt')
+        return cls(args, config)
+
+    @abstractmethod
+    def run(self):
+        raise dbt.exceptions.NotImplementedException('Not Implemented')
+
+    def interpret_results(self, results):
+        return True
+
+
+def get_nearest_project_dir():
+    root_path = os.path.abspath(os.sep)
+    cwd = os.getcwd()
+
+    while cwd != root_path:
+        project_file = os.path.join(cwd, "dbt_project.yml")
+        if os.path.exists(project_file):
+            return cwd
+        cwd = os.path.dirname(cwd)
+
+    return None
+
+
+def move_to_nearest_project_dir():
+    nearest_project_dir = get_nearest_project_dir()
+    if nearest_project_dir is None:
+        raise dbt.exceptions.RuntimeException(
+            "fatal: Not a dbt project (or any of the parent directories). "
+            "Missing dbt_project.yml file"
+        )
+
+    os.chdir(nearest_project_dir)
+
+
+class RequiresProjectTask(BaseTask):
+    @classmethod
+    def from_args(cls, args):
+        move_to_nearest_project_dir()
+        return super(RequiresProjectTask, cls).from_args(args)
+
+
+class ConfiguredTask(RequiresProjectTask):
+    ConfigType = RuntimeConfig
+
+
+class ProjectOnlyTask(RequiresProjectTask):
+    ConfigType = Project
diff --git a/core/dbt/task/base_task.py b/core/dbt/task/base_task.py
deleted file mode 100644
index db8cedbff45..00000000000
--- a/core/dbt/task/base_task.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import dbt.exceptions
-
-
-class BaseTask(object):
-    def __init__(self, args, config=None):
-        self.args = args
-        self.config = config
-
-    def run(self):
-        raise dbt.exceptions.NotImplementedException('Not Implemented')
-
-    def interpret_results(self, results):
-        return True
diff --git a/core/dbt/task/clean.py b/core/dbt/task/clean.py
index f7b524057b8..ab0ef081b10 100644
--- a/core/dbt/task/clean.py
+++ b/core/dbt/task/clean.py
@@ -2,10 +2,10 @@
 import os
 import shutil

-from dbt.task.base_task import BaseTask
+from dbt.task.base import ProjectOnlyTask


-class CleanTask(BaseTask):
+class CleanTask(ProjectOnlyTask):

    def __is_project_path(self, path):
        proj_path = os.path.abspath('.')
diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py
index ac7f49ec2c8..3aae70d67da 100644
--- a/core/dbt/task/compile.py
+++ b/core/dbt/task/compile.py
@@ -1,11 +1,20 @@
-from dbt.node_runners import CompileRunner
+import os
+
+from dbt.adapters.factory import get_adapter
+from dbt.compilation import compile_manifest
+from dbt.exceptions import RuntimeException
+from dbt.loader import load_all_projects, GraphLoader
+from dbt.node_runners import CompileRunner, RPCCompileRunner
 from dbt.node_types import NodeType
+from dbt.parser.analysis import RPCCallParser
+from dbt.parser.util import ParserUtils
 import dbt.ui.printer

-from dbt.task.runnable import RunnableTask
+from dbt.task.runnable import ManifestTask, GraphRunnableTask, RemoteCallable

+
+class CompileTask(GraphRunnableTask):

-class CompileTask(RunnableTask):
    def raise_on_first_error(self):
        return True
@@ -22,3 +31,63 @@ def get_runner_type(self):

    def task_end_messages(self, results):
        dbt.ui.printer.print_timestamped_line('Done.')
+
+
+class RemoteCompileTask(CompileTask, RemoteCallable):
+    METHOD_NAME = 'compile'
+
+    def __init__(self, args, config):
+        super(CompileTask, self).__init__(args, config)
+        self.parser = None
+        self._base_manifest = GraphLoader.load_all(
+            config,
+            internal_manifest=get_adapter(config).check_internal_manifest()
+        )
+
+    def get_runner_type(self):
+        return RPCCompileRunner
+
+    def runtime_cleanup(self, selected_uids):
+        """Do some pre-run cleanup that is usually performed in Task __init__.
+ """ + self.run_count = 0 + self.num_nodes = len(selected_uids) + self.node_results = [] + self._skipped_children = {} + self._skipped_children = {} + self._raise_next_tick = None + + def handle_request(self, name, sql): + self.parser = RPCCallParser( + self.config, + all_projects=load_all_projects(self.config), + macro_manifest=self._base_manifest + ) + + sql = self.decode_sql(sql) + request_path = os.path.join(self.config.target_path, 'rpc', name) + node_dict = { + 'name': name, + 'root_path': request_path, + 'resource_type': NodeType.RPCCall, + 'path': name+'.sql', + 'original_file_path': 'from remote system', + 'package_name': self.config.project_name, + 'raw_sql': sql, + } + unique_id, node = self.parser.parse_sql_node(node_dict) + + self.manifest = ParserUtils.add_new_refs( + manifest=self._base_manifest, + current_project=self.config, + node=node + ) + # don't write our new, weird manifest! + self.linker = compile_manifest(self.config, self.manifest, write=False) + selected_uids = [node.unique_id] + self.runtime_cleanup(selected_uids) + self.job_queue = self.linker.as_graph_queue(self.manifest, + selected_uids) + + result = self.get_runner(node).safe_run(self.manifest) + + return result.serialize() diff --git a/core/dbt/task/debug.py b/core/dbt/task/debug.py index 6a141cd7b3c..fbc9542cce7 100644 --- a/core/dbt/task/debug.py +++ b/core/dbt/task/debug.py @@ -16,7 +16,7 @@ from dbt.clients.yaml_helper import load_yaml_text from dbt.ui.printer import green, red -from dbt.task.base_task import BaseTask +from dbt.task.base import BaseTask PROFILE_DIR_MESSAGE = """To view your profiles.yml file, run: @@ -59,7 +59,7 @@ class DebugTask(BaseTask): - def __init__(self, args, config=None): + def __init__(self, args, config): super(DebugTask, self).__init__(args, config) self.profiles_dir = getattr(self.args, 'profiles_dir', dbt.config.PROFILES_DIR) diff --git a/core/dbt/task/deps.py b/core/dbt/task/deps.py index 2fe91a8af66..3e282c25d7b 100644 --- a/core/dbt/task/deps.py +++ b/core/dbt/task/deps.py @@ -21,7 +21,7 @@ GIT_PACKAGE_CONTRACT, REGISTRY_PACKAGE_CONTRACT, \ REGISTRY_PACKAGE_METADATA_CONTRACT, PackageConfig -from dbt.task.base_task import BaseTask +from dbt.task.base import ProjectOnlyTask DOWNLOADS_PATH = None REMOVE_DOWNLOADS = False @@ -440,7 +440,7 @@ def _read_packages(project_yaml): return packages -class DepsTask(BaseTask): +class DepsTask(ProjectOnlyTask): def __init__(self, args, config=None): super(DepsTask, self).__init__(args=args, config=config) self._downloads_path = None diff --git a/core/dbt/task/freshness.py b/core/dbt/task/freshness.py index b3956daa917..75bbd4f5ba9 100644 --- a/core/dbt/task/freshness.py +++ b/core/dbt/task/freshness.py @@ -1,5 +1,5 @@ import os -from dbt.task.runnable import BaseRunnableTask +from dbt.task.runnable import GraphRunnableTask from dbt.node_runners import FreshnessRunner from dbt.node_types import NodeType from dbt.ui.printer import print_timestamped_line, print_run_result_error @@ -8,7 +8,7 @@ RESULT_FILE_NAME = 'sources.json' -class FreshnessTask(BaseRunnableTask): +class FreshnessTask(GraphRunnableTask): def result_path(self): if self.args.output: return os.path.realpath(self.args.output) diff --git a/core/dbt/task/init.py b/core/dbt/task/init.py index 9f8569b9481..ccece1ddca4 100644 --- a/core/dbt/task/init.py +++ b/core/dbt/task/init.py @@ -6,7 +6,7 @@ from dbt.logger import GLOBAL_LOGGER as logger -from dbt.task.base_task import BaseTask +from dbt.task.base import BaseTask STARTER_REPO = 
'https://github.com/fishtown-analytics/dbt-starter-project.git' DOCS_URL = 'https://docs.getdbt.com/docs/configure-your-profile' diff --git a/core/dbt/task/rpc_server.py b/core/dbt/task/rpc_server.py new file mode 100644 index 00000000000..e78769eccb5 --- /dev/null +++ b/core/dbt/task/rpc_server.py @@ -0,0 +1,76 @@ +import json +import os + +from jsonrpc import Dispatcher, JSONRPCResponseManager + +from werkzeug.wrappers import Request, Response +from werkzeug.serving import run_simple + +from dbt.logger import RPC_LOGGER as logger +from dbt.task.base import ConfiguredTask +from dbt.task.compile import CompileTask, RemoteCompileTask +from dbt.task.run import RemoteRunTask +from dbt.utils import JSONEncoder + + +class RPCServerTask(ConfiguredTask): + def __init__(self, args, config, tasks=None): + super(RPCServerTask, self).__init__(args, config) + # compile locally + self.compile_task = CompileTask(args, config) + self.compile_task.run() + self.dispatcher = Dispatcher() + tasks = tasks or [RemoteCompileTask, RemoteRunTask] + for cls in tasks: + self.register(cls(args, config)) + + def register(self, task): + self.dispatcher.add_method(task.safe_handle_request, + name=task.METHOD_NAME) + + @property + def manifest(self): + return self.compile_task.manifest + + def run(self): + host = self.args.host + port = self.args.port + addr = (host, port) + + display_host = host + if host == '0.0.0.0': + display_host = 'localhost' + + logger.info( + 'Serving RPC server at {}:{}'.format(*addr) + ) + + logger.info( + 'Supported methods: {}'.format(list(self.dispatcher.keys())) + ) + + logger.info( + 'Send requests to http://{}:{}'.format(display_host, port) + ) + + run_simple(host, port, self.handle_request, + processes=self.config.threads) + + @Request.application + def handle_request(self, request): + msg = 'Received request ({0}) from {0.remote_addr}, data={0.data}' + logger.info(msg.format(request)) + # request_data is the request as a parsedjson object + response = JSONRPCResponseManager.handle( + request.data, self.dispatcher + ) + json_data = json.dumps(response.data, cls=JSONEncoder) + response = Response(json_data, mimetype='application/json') + # this looks and feels dumb, but our json encoder converts decimals and + # datetimes, and if we use the json_data itself the output looks silly + # because of escapes, so re-serialize it into valid JSON types for + # logging. 
+        logger.info('sending response ({}) to {}, data={}'.format(
+            response, request.remote_addr, json.loads(json_data))
+        )
+        return response
diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py
index 7c9a7c6b418..3ab251aee1c 100644
--- a/core/dbt/task/run.py
+++ b/core/dbt/task/run.py
@@ -2,7 +2,7 @@

 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt.node_types import NodeType, RunHookType
-from dbt.node_runners import ModelRunner
+from dbt.node_runners import ModelRunner, RPCExecuteRunner

 import dbt.exceptions
 import dbt.flags
@@ -11,7 +11,7 @@
 from dbt.hooks import get_hook_dict
 from dbt.compilation import compile_node

-from dbt.task.compile import CompileTask
+from dbt.task.compile import CompileTask, RemoteCompileTask
 from dbt.utils import get_nodes_by_tags


@@ -114,3 +114,10 @@ def get_runner_type(self):

    def task_end_messages(self, results):
        if results:
            dbt.ui.printer.print_run_end_messages(results)
+
+
+class RemoteRunTask(RemoteCompileTask, RunTask):
+    METHOD_NAME = 'run'
+
+    def get_runner_type(self):
+        return RPCExecuteRunner
diff --git a/core/dbt/task/run_operation.py b/core/dbt/task/run_operation.py
index a2ce4ff5c9b..6ea5d91bd98 100644
--- a/core/dbt/task/run_operation.py
+++ b/core/dbt/task/run_operation.py
@@ -1,6 +1,6 @@
 from dbt.logger import GLOBAL_LOGGER as logger

-from dbt.task.base_task import BaseTask
+from dbt.task.base import BaseTask

 from dbt.adapters.factory import get_adapter
 from dbt.loader import GraphLoader
diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py
index 06694382d44..cfc4d22b69b 100644
--- a/core/dbt/task/runnable.py
+++ b/core/dbt/task/runnable.py
@@ -1,9 +1,17 @@
+import base64
 import os
+import re
 import time
+from abc import abstractmethod
+from multiprocessing import Process, Pipe
+from multiprocessing.dummy import Pool as ThreadPool
+
+import six

-from dbt.task.base_task import BaseTask
+from dbt.task.base import ConfiguredTask
 from dbt.adapters.factory import get_adapter
 from dbt.logger import GLOBAL_LOGGER as logger
+from dbt.compat import abstractclassmethod, to_unicode
 from dbt.compilation import compile_manifest
 from dbt.contracts.graph.manifest import CompileResultNode
 from dbt.contracts.results import ExecutionResult
@@ -15,8 +23,6 @@
 import dbt.graph.selector

-from multiprocessing.dummy import Pool as ThreadPool
-
 RESULT_FILE_NAME = 'run_results.json'
 MANIFEST_FILE_NAME = 'manifest.json'
@@ -32,11 +38,20 @@ def load_manifest(config):
    return manifest


-class BaseRunnableTask(BaseTask):
+class ManifestTask(ConfiguredTask):
    def __init__(self, args, config):
-        super(BaseRunnableTask, self).__init__(args, config)
+        super(ManifestTask, self).__init__(args, config)
        self.manifest = None
        self.linker = None
+
+    def _runtime_initialize(self):
+        self.manifest = load_manifest(self.config)
+        self.linker = compile_manifest(self.config, self.manifest)
+
+
+class GraphRunnableTask(ManifestTask):
+    def __init__(self, args, config):
+        super(GraphRunnableTask, self).__init__(args, config)
        self.job_queue = None
        self._flattened_nodes = None
@@ -46,12 +61,14 @@ def __init__(self, args, config):
        self._skipped_children = {}
        self._raise_next_tick = None

-    def _runtime_initialize(self):
-        self.manifest = load_manifest(self.config)
-        self.linker = compile_manifest(self.config, self.manifest)
-
+    def select_nodes(self):
        selector = dbt.graph.selector.NodeSelector(self.linker, self.manifest)
        selected_nodes = selector.select(self.build_query())
+        return selected_nodes
+
+    def _runtime_initialize(self):
+        super(GraphRunnableTask, self)._runtime_initialize()
+        selected_nodes = self.select_nodes()
        self.job_queue = self.linker.as_graph_queue(self.manifest,
                                                    selected_nodes)
@@ -229,27 +246,8 @@ def after_run(self, adapter, results):
    def after_hooks(self, adapter, results, elapsed):
        pass

-    def task_end_messages(self, results):
-        raise dbt.exceptions.NotImplementedException('Not Implemented')
-
-    def get_result(self, results, elapsed_time, generated_at):
-        raise dbt.exceptions.NotImplementedException('Not Implemented')
-
-    def run(self):
-        """
-        Run dbt for the query, based on the graph.
-        """
-        self._runtime_initialize()
+    def execute_with_hooks(self, selected_uids):
        adapter = get_adapter(self.config)
-
-        if len(self._flattened_nodes) == 0:
-            logger.info("WARNING: Nothing to do. Try checking your model "
-                        "configs and model specification args")
-            return []
-        else:
-            logger.info("")
-
-        selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
        try:
            self.before_hooks(adapter)
            started = time.time()
@@ -267,10 +265,28 @@ def run(self):
            elapsed_time=elapsed,
            generated_at=dbt.utils.timestring()
        )
+        return result
+
+    def run(self):
+        """
+        Run dbt for the query, based on the graph.
+        """
+        self._runtime_initialize()
+
+        if len(self._flattened_nodes) == 0:
+            logger.info("WARNING: Nothing to do. Try checking your model "
+                        "configs and model specification args")
+            return []
+        else:
+            logger.info("")
+
+        selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
+        result = self.execute_with_hooks(selected_uids)

        result.write(self.result_path())

-        self.task_end_messages(res)
-        return res
+        self.task_end_messages(result.results)
+        return result.results

    def interpret_results(self, results):
        if results is None:
@@ -279,8 +295,6 @@ def interpret_results(self, results):
        failures = [r for r in results if r.error or r.failed]
        return len(failures) == 0

-
-class RunnableTask(BaseRunnableTask):
    def get_model_schemas(self, selected_uids):
        schemas = set()
        for node in self.manifest.nodes.values():
@@ -320,3 +334,83 @@ def get_result(self, results, elapsed_time, generated_at):

    def task_end_messages(self, results):
        dbt.ui.printer.print_run_end_messages(results)
+
+
+class RemoteCallable(object):
+    METHOD_NAME = None
+    is_async = False
+
+    @abstractmethod
+    def handle_request(self, **kwargs):
+        raise dbt.exceptions.NotImplementedException(
+            'from_kwargs not implemented'
+        )
+
+    def _subprocess_handle_request(self, conn, **kwargs):
+        error = None
+        result = None
+        try:
+            result = self.handle_request(**kwargs)
+        except dbt.exceptions.RuntimeException as exc:
+            logger.debug('dbt runtime exception',
+                         exc_info=True)
+            # we have to convert this to a string for RPC responses
+            error = str(exc)
+        except dbt.exceptions.RPCException as exc:
+            error = str(exc)
+        except Exception as exc:
+            logger.debug('uncaught python exception',
+                         exc_info=True)
+            error = str(exc)
+        conn.send([result, error])
+        conn.close()
+
+    def safe_handle_request(self, **kwargs):
+        # assumption here: we are within a thread/process already and can
+        # block however we like to enforce the timeout
+        timeout = kwargs.pop('timeout', None)
+        parent_conn, child_conn = Pipe()
+        proc = Process(
+            target=self._subprocess_handle_request,
+            args=(child_conn,),
+            kwargs=kwargs
+        )
+        proc.start()
+        if parent_conn.poll(timeout):
+            result, error = parent_conn.recv()
+        else:
+            error = 'timed out after {}s'.format(timeout)
+            proc.terminate()
+
+        parent_conn.close()
+
+        proc.join()
+        if error:
+            raise dbt.exceptions.RPCException(error)
+        else:
+            return result
+
+    def decode_sql(self, sql):
+        """Base64 decode a string. This should only be used for sql in calls.
+
+        :param str sql: The base64 encoded form of the original utf-8 string
+        :return str: The decoded utf-8 string
+        """
+        # JSON is defined as using "unicode", we'll go a step further and
+        # mandate utf-8 (though for the base64 part, it doesn't really matter!)
+        base64_sql_bytes = to_unicode(sql).encode('utf-8')
+        # in python3.x you can pass `validate=True` to b64decode to get this
+        # behavior.
+        if not re.match(b'^[A-Za-z0-9+/]*={0,2}$', base64_sql_bytes):
+            raise dbt.exceptions.RPCException(
+                'invalid base64-encoded sql input: {!s}'.format(sql)
+            )
+
+        try:
+            sql_bytes = base64.b64decode(base64_sql_bytes)
+        except ValueError as exc:
+            raise dbt.exceptions.RPCException(
+                'invalid base64-encoded sql input: {!s}'.format(exc)
+            )
+
+        return sql_bytes.decode('utf-8')
diff --git a/core/dbt/task/serve.py b/core/dbt/task/serve.py
index d8ce756b75c..7698a47cd58 100644
--- a/core/dbt/task/serve.py
+++ b/core/dbt/task/serve.py
@@ -6,10 +6,10 @@
 from dbt.compat import SimpleHTTPRequestHandler, TCPServer
 from dbt.logger import GLOBAL_LOGGER as logger

-from dbt.task.base_task import BaseTask
+from dbt.task.base import ProjectOnlyTask


-class ServeTask(BaseTask):
+class ServeTask(ProjectOnlyTask):
    def run(self):
        os.chdir(self.config.target_path)
diff --git a/core/dbt/utils.py b/core/dbt/utils.py
index edf283bda6b..c5fc8977e98 100644
--- a/core/dbt/utils.py
+++ b/core/dbt/utils.py
@@ -453,6 +453,8 @@ class JSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, DECIMALS):
            return float(obj)
+        if isinstance(obj, datetime):
+            return obj.isoformat()
        return super(JSONEncoder, self).default(obj)
diff --git a/core/setup.py b/core/setup.py
index 3b859408299..ed063dd99fd 100644
--- a/core/setup.py
+++ b/core/setup.py
@@ -52,5 +52,7 @@ def read(fname):
        'colorama==0.3.9',
        'agate>=1.6,<2',
        'jsonschema==2.6.0',
+        'json-rpc>=1.12,<2',
+        'werkzeug>=0.14.1,<0.15',
    ]
 )
diff --git a/test/integration/042_sources_test/test_sources.py b/test/integration/042_sources_test/test_sources.py
index e44bb6a6a2f..111feb124c6 100644
--- a/test/integration/042_sources_test/test_sources.py
+++ b/test/integration/042_sources_test/test_sources.py
@@ -1,3 +1,4 @@
+import unittest
 from nose.plugins.attrib import attr
 from datetime import datetime, timedelta
 import json
@@ -6,7 +7,7 @@
 from dbt.exceptions import CompilationException
 from test.integration.base import DBTIntegrationTest, use_profile, AnyFloat, \
    AnyStringWith
-
+from dbt.main import handle_and_check

 class BaseSourcesTest(DBTIntegrationTest):
    @property
@@ -113,6 +114,7 @@ def test_source_childrens_parents(self):
        )
        self.assertTableDoesNotExist('nonsource_descendant')

+
 class TestSourceFreshness(BaseSourcesTest):
    def setUp(self):
        super(TestSourceFreshness, self).setUp()
@@ -256,3 +258,276 @@ def test_postgres_malformed_schema_nonstrict_will_not_break_run(self):
    def test_postgres_malformed_schema_strict_will_break_run(self):
        with self.assertRaises(CompilationException):
            self.run_dbt_with_vars(['run'], strict=True)
+
+
+import multiprocessing
+from base64 import standard_b64encode as b64
+import json
+import requests
+import socket
+import time
+import os
+
+
+class ServerProcess(multiprocessing.Process):
+    def __init__(self, cli_vars=None):
+        self.port = 22991
+        handle_and_check_args = [
+            '--strict', 'rpc', '--log-cache-events',
+            '--port', str(self.port),
+        ]
+        if cli_vars:
+            handle_and_check_args.extend(['--vars', cli_vars])
+        super(ServerProcess, self).__init__(
+            target=handle_and_check,
+            args=(handle_and_check_args,))
+
+    def is_up(self):
+        sock = socket.socket()
+        try:
+            sock.connect(('localhost', self.port))
+        except socket.error:
+            return False
+        sock.close()
+        return True
+
+    def start(self):
+        super(ServerProcess, self).start()
+        for _ in range(10):
+            if self.is_up():
+                break
+            time.sleep(0.5)
+        if not self.is_up():
+            self.terminate()
+            raise Exception('server never appeared!')
+
+
+@unittest.skipIf(os.name == 'nt', 'Windows not supported for now')
+class TestRPCServer(BaseSourcesTest):
+    def setUp(self):
+        super(TestRPCServer, self).setUp()
+        self._server = ServerProcess(
+            cli_vars='{{test_run_schema: {}}}'.format(self.unique_schema())
+        )
+        self._server.start()
+
+    def tearDown(self):
+        self._server.terminate()
+        super(TestRPCServer, self).tearDown()
+
+    def build_query(self, method, kwargs, sql=None, test_request_id=1):
+        if sql is not None:
+            kwargs['sql'] = b64(sql.encode('utf-8')).decode('utf-8')
+
+        return {
+            'jsonrpc': '2.0',
+            'method': method,
+            'params': kwargs,
+            'id': test_request_id
+        }
+
+    def perform_query(self, query):
+        url = 'http://localhost:{}/jsonrpc'.format(self._server.port)
+        headers = {'content-type': 'application/json'}
+        response = requests.post(url, headers=headers, data=json.dumps(query))
+        return response
+
+    def query(self, _method, _sql=None, _test_request_id=1, **kwargs):
+        built = self.build_query(_method, kwargs, _sql, _test_request_id)
+        return self.perform_query(built)
+
+    def assertResultHasTimings(self, result, *names):
+        self.assertIn('timing', result)
+        timings = result['timing']
+        self.assertEqual(len(timings), len(names))
+        for expected_name, timing in zip(names, timings):
+            self.assertIn('name', timing)
+            self.assertEqual(timing['name'], expected_name)
+            self.assertIn('started_at', timing)
+            self.assertIn('completed_at', timing)
+            datetime.strptime(timing['started_at'], '%Y-%m-%dT%H:%M:%S.%fZ')
+            datetime.strptime(timing['completed_at'], '%Y-%m-%dT%H:%M:%S.%fZ')
+
+    def assertIsResult(self, data):
+        self.assertEqual(data['id'], 1)
+        self.assertEqual(data['jsonrpc'], '2.0')
+        self.assertIn('result', data)
+        self.assertNotIn('error', data)
+        return data['result']
+
+    def assertIsError(self, data):
+        self.assertEqual(data['id'], 1)
+        self.assertEqual(data['jsonrpc'], '2.0')
+        self.assertIn('error', data)
+        self.assertNotIn('result', data)
+        return data['error']
+
+    def assertIsErrorWithCode(self, data, code):
+        error = self.assertIsError(data)
+        self.assertIn('code', error)
+        self.assertIn('message', error)
+        self.assertEqual(error['code'], code)
+        return error
+
+    def assertResultHasSql(self, data, raw_sql, compiled_sql=None):
+        if compiled_sql is None:
+            compiled_sql = raw_sql
+        result = self.assertIsResult(data)
+        self.assertIn('raw_sql', result)
+        self.assertIn('compiled_sql', result)
+        self.assertEqual(result['raw_sql'], raw_sql)
+        self.assertEqual(result['compiled_sql'], compiled_sql)
+        return result
+
+    def assertSuccessfulCompilationResult(self, data, raw_sql, compiled_sql=None):
+        result = self.assertResultHasSql(data, raw_sql, compiled_sql)
+        self.assertNotIn('table', result)
+        # compile results still have an 'execute' timing, it just represents
+        # the time to construct a result object.
+        self.assertResultHasTimings(result, 'compile', 'execute')
+
+    def assertSuccessfulRunResult(self, data, raw_sql, compiled_sql=None, table=None):
+        result = self.assertResultHasSql(data, raw_sql, compiled_sql)
+        self.assertIn('table', result)
+        if table is not None:
+            self.assertEqual(result['table'], table)
+        self.assertResultHasTimings(result, 'compile', 'execute')
+
+    @use_profile('postgres')
+    def test_compile(self):
+        trivial = self.query(
+            'compile',
+            'select 1 as id',
+            name='foo'
+        ).json()
+        self.assertSuccessfulCompilationResult(
+            trivial, 'select 1 as id'
+        )
+
+        ref = self.query(
+            'compile',
+            'select * from {{ ref("descendant_model") }}',
+            name='foo'
+        ).json()
+        self.assertSuccessfulCompilationResult(
+            ref,
+            'select * from {{ ref("descendant_model") }}',
+            compiled_sql='select * from "{}"."{}"."descendant_model"'.format(
+                self.default_database,
+                self.unique_schema())
+        )
+
+        source = self.query(
+            'compile',
+            'select * from {{ source("test_source", "test_table") }}',
+            name='foo'
+        ).json()
+
+        self.assertSuccessfulCompilationResult(
+            source,
+            'select * from {{ source("test_source", "test_table") }}',
+            compiled_sql='select * from "{}"."{}"."source"'.format(
+                self.default_database,
+                self.unique_schema())
+        )
+
+    @use_profile('postgres')
+    def test_run(self):
+        # seed + run dbt to make models before using them!
+        self.run_dbt_with_vars(['seed'])
+        self.run_dbt_with_vars(['run'])
+        data = self.query(
+            'run',
+            'select 1 as id',
+            name='foo'
+        ).json()
+        self.assertSuccessfulRunResult(
+            data, 'select 1 as id', table={'column_names': ['id'], 'rows': [[1.0]]}
+        )
+
+        ref = self.query(
+            'run',
+            'select * from {{ ref("descendant_model") }} order by updated_at limit 1',
+            name='foo'
+        ).json()
+        self.assertSuccessfulRunResult(
+            ref,
+            'select * from {{ ref("descendant_model") }} order by updated_at limit 1',
+            compiled_sql='select * from "{}"."{}"."descendant_model" order by updated_at limit 1'.format(
+                self.default_database,
+                self.unique_schema()),
+            table={
+                'column_names': ['favorite_color', 'id', 'first_name', 'email', 'ip_address', 'updated_at'],
+                'rows': [['blue', 38.0, 'Gary', 'gray11@statcounter.com', "'40.193.124.56'", '1970-01-27T10:04:51']],
+            }
+        )
+
+        source = self.query(
+            'run',
+            'select * from {{ source("test_source", "test_table") }} order by updated_at limit 1',
+            name='foo'
+        ).json()
+
+        self.assertSuccessfulRunResult(
+            source,
+            'select * from {{ source("test_source", "test_table") }} order by updated_at limit 1',
+            compiled_sql='select * from "{}"."{}"."source" order by updated_at limit 1'.format(
+                self.default_database,
+                self.unique_schema()),
+            table={
+                'column_names': ['favorite_color', 'id', 'first_name', 'email', 'ip_address', 'updated_at'],
+                'rows': [['blue', 38.0, 'Gary', 'gray11@statcounter.com', "'40.193.124.56'", '1970-01-27T10:04:51']],
+            }
+        )
+
+    @use_profile('postgres')
+    def test_invalid_requests(self):
+        data = self.query(
+            'xxxxxnotamethodxxxxx',
+            'hi this is not sql'
+        ).json()
+        error = self.assertIsErrorWithCode(data, -32601)
+        self.assertEqual(error['message'], 'Method not found')
+
+        data = self.query(
+            'compile',
+            'select * from {{ reff("nonsource_descendant") }}',
+            name='mymodel'
+        ).json()
+        error = self.assertIsErrorWithCode(data, -32000)
+        self.assertEqual(error['message'], 'Server error')
+        self.assertIn('data', error)
+        self.assertEqual(error['data']['type'], 'RPCException')
+        self.assertEqual(
+            error['data']['message'],
+            "Compilation Error in rpc mymodel (from remote system)\n"
+            " 'reff' is undefined"
+        )
+
+        data = self.query(
+            'run',
+            'hi this is not sql',
+            name='foo'
+        ).json()
+        error = self.assertIsErrorWithCode(data, -32000)
+        self.assertEqual(error['message'], 'Server error')
+        self.assertIn('data', error)
+        self.assertEqual(error['data']['type'], 'RPCException')
+        self.assertEqual(
+            error['data']['message'],
+            'Database Error in rpc foo (from remote system)\n'
+            ' syntax error at or near "hi"\n'
+            ' LINE 1: hi this is not sql\n'
+            ' ^'
+        )
+
+    @use_profile('postgres')
+    def test_timeout(self):
+        data = self.query(
+            'run',
+            'select from pg_sleep(5)',
+            name='foo',
+            timeout=1
+        ).json()
+        error = self.assertIsErrorWithCode(data, -32000)
+        self.assertEqual(error['message'], 'Server error')
+        self.assertIn('data', error)
+        self.assertEqual(error['data']['type'], 'RPCException')
+        self.assertEqual(error['data']['message'], 'timed out after 1s')
diff --git a/test/unit/test_config.py b/test/unit/test_config.py
index 1163468a535..ea6a9fd50f0 100644
--- a/test/unit/test_config.py
+++ b/test/unit/test_config.py
@@ -357,7 +357,6 @@ def from_args(self, project_profile_name='default', **kwargs):
        kw = {
            'args': self.args,
            'project_profile_name': project_profile_name,
-            'cli_vars': {},
        }
        kw.update(kwargs)
        return dbt.config.Profile.from_args(**kw)
@@ -484,7 +483,7 @@ def test_cli_and_env_vars(self):
        self.args.target = 'cli-and-env-vars'
        self.args.vars = '{"cli_value_host": "cli-postgres-host"}'
        with mock.patch.dict(os.environ, self.env_override):
-            profile = self.from_args(cli_vars=None)
+            profile = self.from_args()
        from_raw = self.from_raw_profile_info(
            target_override='cli-and-env-vars',
            cli_vars={'cli_value_host': 'cli-postgres-host'},