From 19a4882d7e09b75a3c5b08569fb9e4d2bbadd483 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Wed, 11 Mar 2020 09:28:56 +0000 Subject: [PATCH] Added --fail-fast argument for dbt run and dbt test --- CHANGELOG.md | 15 +++++++ core/dbt/exceptions.py | 11 +++++ core/dbt/linker.py | 8 ++++ core/dbt/main.py | 16 +++++++ core/dbt/task/runnable.py | 87 +++++++++++++++++++++++++-------------- core/dbt/ui/printer.py | 10 +++-- 6 files changed, 112 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d2b612175e..f77ac91f1a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +## dbt next (release TBD) +### Features +- Added --fail-fast argument for dbt run and dbt test to fail on first test failure or runtime error. ([#1649](https://github.com/fishtown-analytics/dbt/issues/1649)) + +### Fixes +- When a jinja value is undefined, give a helpful error instead of failing with cryptic "cannot pickle ParserMacroCapture" errors ([#2110](https://github.com/fishtown-analytics/dbt/issues/2110), [#2184](https://github.com/fishtown-analytics/dbt/pull/2184)) + +Contributors: + - [@raalsky](https://github.com/Raalsky) ([#2224](https://github.com/fishtown-analytics/dbt/pull/2224)) + +## dbt 0.16.0rc2 (March 4, 2020) + +### Under the hood +- Pin cffi to <1.14 to avoid a version conflict with snowflake-connector-python ([#2180](https://github.com/fishtown-analytics/dbt/issues/2180), [#2181](https://github.com/fishtown-analytics/dbt/pull/2181)) + ## dbt 0.16.0rc1 (March 4, 2020) ### Breaking changes diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index cbbe3c74752..894d50758b9 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -281,6 +281,17 @@ def __init__(self, message, project=None, result_type='invalid_project'): self.result_type = result_type +class FailFastException(Exception): + CODE = 10013 + MESSAGE = 'FailFast Error' + + def __init__(self, result=None): + self.result = result + + def __str__(self): + return 'Falling early due to test failure or runtime error' + + class DbtProjectError(DbtConfigError): pass diff --git a/core/dbt/linker.py b/core/dbt/linker.py index 991d704901d..1170806b11b 100644 --- a/core/dbt/linker.py +++ b/core/dbt/linker.py @@ -48,6 +48,8 @@ def __init__(self, graph, manifest): self._scores = self._calculate_scores() # populate the initial queue self._find_new_additions() + # awaits after task end + self.some_task_done = threading.Condition(self.lock) def _include_in_cost(self, node_id): node = self.manifest.expect(node_id) @@ -153,6 +155,7 @@ def mark_done(self, node_id): self.graph.remove_node(node_id) self._find_new_additions() self.inner.task_done() + self.some_task_done.notify_all() def _mark_in_progress(self, node_id): """Mark the node as 'in progress'. @@ -171,6 +174,11 @@ def join(self): """ self.inner.join() + def wait_until_something_was_done(self): + with self.some_task_done: + self.some_task_done.wait() + return self.inner.unfinished_tasks + class Linker: def __init__(self, data=None): diff --git a/core/dbt/main.py b/core/dbt/main.py index 12fa3bb2301..0da30df2695 100644 --- a/core/dbt/main.py +++ b/core/dbt/main.py @@ -398,6 +398,14 @@ def _build_run_subparser(subparsers, base_subparser): help=''' Compile SQL and execute against the current target database. ''') + run_sub.add_argument( + '-x', + '--fail-fast', + action='store_true', + help=''' + Stop execution upon a first failure. + ''' + ) run_sub.set_defaults(cls=run_task.RunTask, which='run', rpc_method='run') return run_sub @@ -555,6 +563,14 @@ def _build_test_subparser(subparsers, base_subparser): Run constraint validations from schema.yml files ''' ) + sub.add_argument( + '-x', + '--fail-fast', + action='store_true', + help=''' + Stop execution upon a first test failure. + ''' + ) sub.set_defaults(cls=test_task.TestTask, which='test', rpc_method='test') return sub diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 33e6a78ff58..48fac3d43f4 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -24,7 +24,10 @@ from dbt.contracts.graph.manifest import Manifest from dbt.contracts.results import ExecutionResult from dbt.exceptions import ( - InternalException, NotImplementedException, RuntimeException + InternalException, + NotImplementedException, + RuntimeException, + FailFastException ) from dbt.linker import Linker, GraphQueue from dbt.perf_utils import get_full_manifest @@ -160,11 +163,16 @@ def call_runner(self, runner): with finishctx, DbtModelState(status): logger.debug('Finished running node {}'.format( runner.node.unique_id)) - if result.error is not None and self.raise_on_first_error(): + + fail_fast = getattr(self.config.args, 'fail_fast', False) + + if (result.fail is not None or result.error is not None) and fail_fast: + self._raise_next_tick = FailFastException(result) + elif result.error is not None and self.raise_on_first_error(): # if we raise inside a thread, it'll just get silently swallowed. # stash the error message we want here, and it will check the # next 'tick' - should be soon since our thread is about to finish! - self._raise_next_tick = result.error + self._raise_next_tick = RuntimeException(result.error) return result @@ -184,7 +192,7 @@ def _submit(self, pool, args, callback): def _raise_set_error(self): if self._raise_next_tick is not None: - raise RuntimeException(self._raise_next_tick) + raise self._raise_next_tick def run_queue(self, pool): """Given a pool, submit jobs from the queue to the pool. @@ -219,7 +227,15 @@ def callback(result): self._submit(pool, args, callback) # block on completion - self.job_queue.join() + if getattr(self.config.args, 'fail_fast', False): + # checkout for an errors after task completion in case of + # fast failure + while self.job_queue.wait_until_something_was_done(): + self._raise_set_error() + else: + # wait until every task will be complete + self.job_queue.join() + # if an error got set during join(), raise it. self._raise_set_error() @@ -248,6 +264,34 @@ def _handle_result(self, result): cause = None self._mark_dependent_errors(node.unique_id, result, cause) + def _cancel_connections(self, pool): + """Given a pool, cancel all adapter connections and wait until all + runners gentle terminates. + """ + pool.close() + pool.terminate() + + adapter = get_adapter(self.config) + + if not adapter.is_cancelable(): + msg = ("The {} adapter does not support query " + "cancellation. Some queries may still be " + "running!".format(adapter.type())) + + yellow = dbt.ui.printer.COLOR_FG_YELLOW + dbt.ui.printer.print_timestamped_line(msg, yellow) + raise + + for conn_name in adapter.cancel_open_connections(): + if self.manifest is not None: + node = self.manifest.nodes.get(conn_name) + if node is not None and node.is_ephemeral_model: + continue + # if we don't have a manifest/don't have a node, print anyway. + dbt.ui.printer.print_cancel_line(conn_name) + + pool.join() + def execute_nodes(self): num_threads = self.config.threads target_name = self.config.target_name @@ -263,34 +307,15 @@ def execute_nodes(self): try: self.run_queue(pool) - except KeyboardInterrupt: - pool.close() - pool.terminate() - - adapter = get_adapter(self.config) - - if not adapter.is_cancelable(): - msg = ("The {} adapter does not support query " - "cancellation. Some queries may still be " - "running!".format(adapter.type())) - - yellow = dbt.ui.printer.COLOR_FG_YELLOW - dbt.ui.printer.print_timestamped_line(msg, yellow) - raise - - for conn_name in adapter.cancel_open_connections(): - if self.manifest is not None: - node = self.manifest.nodes.get(conn_name) - if node is not None and node.is_ephemeral_model: - continue - # if we don't have a manifest/don't have a node, print anyway. - dbt.ui.printer.print_cancel_line(conn_name) - - pool.join() + except FailFastException as failure: + self._cancel_connections(pool) + dbt.ui.printer.print_run_result_error(failure.result) + raise + except KeyboardInterrupt: + self._cancel_connections(pool) dbt.ui.printer.print_run_end_messages(self.node_results, - early_exit=True) - + keyboard_interrupt=True) raise pool.close() diff --git a/core/dbt/ui/printer.py b/core/dbt/ui/printer.py index c913274b611..1c6ff051c79 100644 --- a/core/dbt/ui/printer.py +++ b/core/dbt/ui/printer.py @@ -348,11 +348,11 @@ def print_skip_caused_by_error( def print_end_of_run_summary( - num_errors: int, num_warnings: int, early_exit: bool = False + num_errors: int, num_warnings: int, keyboard_interrupt: bool = False ) -> None: error_plural = dbt.utils.pluralize(num_errors, 'error') warn_plural = dbt.utils.pluralize(num_warnings, 'warning') - if early_exit: + if keyboard_interrupt: message = yellow('Exited because of keyboard interrupt.') elif num_errors > 0: message = red("Completed with {} and {}:".format( @@ -367,11 +367,13 @@ def print_end_of_run_summary( logger.info('{}'.format(message)) -def print_run_end_messages(results, early_exit: bool = False) -> None: +def print_run_end_messages(results, keyboard_interrupt: bool = False) -> None: errors = [r for r in results if r.error is not None or r.fail] warnings = [r for r in results if r.warn] with DbtStatusMessage(), InvocationProcessor(): - print_end_of_run_summary(len(errors), len(warnings), early_exit) + print_end_of_run_summary(len(errors), + len(warnings), + keyboard_interrupt) for error in errors: print_run_result_error(error, is_warning=False)