Skip to content

Commit

Permalink
Added --fail-fast argument for dbt run and dbt test
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Mar 20, 2020
1 parent 3925c08 commit 19a4882
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 35 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
## dbt next (release TBD)
### Features
- Added --fail-fast argument for dbt run and dbt test to fail on first test failure or runtime error. ([#1649](https://github.com/fishtown-analytics/dbt/issues/1649))

### Fixes
- When a jinja value is undefined, give a helpful error instead of failing with cryptic "cannot pickle ParserMacroCapture" errors ([#2110](https://github.com/fishtown-analytics/dbt/issues/2110), [#2184](https://github.com/fishtown-analytics/dbt/pull/2184))

Contributors:
- [@raalsky](https://github.com/Raalsky) ([#2224](https://github.com/fishtown-analytics/dbt/pull/2224))

## dbt 0.16.0rc2 (March 4, 2020)

### Under the hood
- Pin cffi to <1.14 to avoid a version conflict with snowflake-connector-python ([#2180](https://github.com/fishtown-analytics/dbt/issues/2180), [#2181](https://github.com/fishtown-analytics/dbt/pull/2181))

## dbt 0.16.0rc1 (March 4, 2020)

### Breaking changes
Expand Down
11 changes: 11 additions & 0 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,17 @@ def __init__(self, message, project=None, result_type='invalid_project'):
self.result_type = result_type


class FailFastException(Exception):
CODE = 10013
MESSAGE = 'FailFast Error'

def __init__(self, result=None):
self.result = result

def __str__(self):
return 'Falling early due to test failure or runtime error'


class DbtProjectError(DbtConfigError):
pass

Expand Down
8 changes: 8 additions & 0 deletions core/dbt/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(self, graph, manifest):
self._scores = self._calculate_scores()
# populate the initial queue
self._find_new_additions()
# awaits after task end
self.some_task_done = threading.Condition(self.lock)

def _include_in_cost(self, node_id):
node = self.manifest.expect(node_id)
Expand Down Expand Up @@ -153,6 +155,7 @@ def mark_done(self, node_id):
self.graph.remove_node(node_id)
self._find_new_additions()
self.inner.task_done()
self.some_task_done.notify_all()

def _mark_in_progress(self, node_id):
"""Mark the node as 'in progress'.
Expand All @@ -171,6 +174,11 @@ def join(self):
"""
self.inner.join()

def wait_until_something_was_done(self):
with self.some_task_done:
self.some_task_done.wait()
return self.inner.unfinished_tasks


class Linker:
def __init__(self, data=None):
Expand Down
16 changes: 16 additions & 0 deletions core/dbt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,14 @@ def _build_run_subparser(subparsers, base_subparser):
help='''
Compile SQL and execute against the current target database.
''')
run_sub.add_argument(
'-x',
'--fail-fast',
action='store_true',
help='''
Stop execution upon a first failure.
'''
)
run_sub.set_defaults(cls=run_task.RunTask, which='run', rpc_method='run')
return run_sub

Expand Down Expand Up @@ -555,6 +563,14 @@ def _build_test_subparser(subparsers, base_subparser):
Run constraint validations from schema.yml files
'''
)
sub.add_argument(
'-x',
'--fail-fast',
action='store_true',
help='''
Stop execution upon a first test failure.
'''
)

sub.set_defaults(cls=test_task.TestTask, which='test', rpc_method='test')
return sub
Expand Down
87 changes: 56 additions & 31 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import ExecutionResult
from dbt.exceptions import (
InternalException, NotImplementedException, RuntimeException
InternalException,
NotImplementedException,
RuntimeException,
FailFastException
)
from dbt.linker import Linker, GraphQueue
from dbt.perf_utils import get_full_manifest
Expand Down Expand Up @@ -160,11 +163,16 @@ def call_runner(self, runner):
with finishctx, DbtModelState(status):
logger.debug('Finished running node {}'.format(
runner.node.unique_id))
if result.error is not None and self.raise_on_first_error():

fail_fast = getattr(self.config.args, 'fail_fast', False)

if (result.fail is not None or result.error is not None) and fail_fast:
self._raise_next_tick = FailFastException(result)
elif result.error is not None and self.raise_on_first_error():
# if we raise inside a thread, it'll just get silently swallowed.
# stash the error message we want here, and it will check the
# next 'tick' - should be soon since our thread is about to finish!
self._raise_next_tick = result.error
self._raise_next_tick = RuntimeException(result.error)

return result

Expand All @@ -184,7 +192,7 @@ def _submit(self, pool, args, callback):

def _raise_set_error(self):
if self._raise_next_tick is not None:
raise RuntimeException(self._raise_next_tick)
raise self._raise_next_tick

def run_queue(self, pool):
"""Given a pool, submit jobs from the queue to the pool.
Expand Down Expand Up @@ -219,7 +227,15 @@ def callback(result):
self._submit(pool, args, callback)

# block on completion
self.job_queue.join()
if getattr(self.config.args, 'fail_fast', False):
# checkout for an errors after task completion in case of
# fast failure
while self.job_queue.wait_until_something_was_done():
self._raise_set_error()
else:
# wait until every task will be complete
self.job_queue.join()

# if an error got set during join(), raise it.
self._raise_set_error()

Expand Down Expand Up @@ -248,6 +264,34 @@ def _handle_result(self, result):
cause = None
self._mark_dependent_errors(node.unique_id, result, cause)

def _cancel_connections(self, pool):
"""Given a pool, cancel all adapter connections and wait until all
runners gentle terminates.
"""
pool.close()
pool.terminate()

adapter = get_adapter(self.config)

if not adapter.is_cancelable():
msg = ("The {} adapter does not support query "
"cancellation. Some queries may still be "
"running!".format(adapter.type()))

yellow = dbt.ui.printer.COLOR_FG_YELLOW
dbt.ui.printer.print_timestamped_line(msg, yellow)
raise

for conn_name in adapter.cancel_open_connections():
if self.manifest is not None:
node = self.manifest.nodes.get(conn_name)
if node is not None and node.is_ephemeral_model:
continue
# if we don't have a manifest/don't have a node, print anyway.
dbt.ui.printer.print_cancel_line(conn_name)

pool.join()

def execute_nodes(self):
num_threads = self.config.threads
target_name = self.config.target_name
Expand All @@ -263,34 +307,15 @@ def execute_nodes(self):
try:
self.run_queue(pool)

except KeyboardInterrupt:
pool.close()
pool.terminate()

adapter = get_adapter(self.config)

if not adapter.is_cancelable():
msg = ("The {} adapter does not support query "
"cancellation. Some queries may still be "
"running!".format(adapter.type()))

yellow = dbt.ui.printer.COLOR_FG_YELLOW
dbt.ui.printer.print_timestamped_line(msg, yellow)
raise

for conn_name in adapter.cancel_open_connections():
if self.manifest is not None:
node = self.manifest.nodes.get(conn_name)
if node is not None and node.is_ephemeral_model:
continue
# if we don't have a manifest/don't have a node, print anyway.
dbt.ui.printer.print_cancel_line(conn_name)

pool.join()
except FailFastException as failure:
self._cancel_connections(pool)
dbt.ui.printer.print_run_result_error(failure.result)
raise

except KeyboardInterrupt:
self._cancel_connections(pool)
dbt.ui.printer.print_run_end_messages(self.node_results,
early_exit=True)

keyboard_interrupt=True)
raise

pool.close()
Expand Down
10 changes: 6 additions & 4 deletions core/dbt/ui/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,11 @@ def print_skip_caused_by_error(


def print_end_of_run_summary(
num_errors: int, num_warnings: int, early_exit: bool = False
num_errors: int, num_warnings: int, keyboard_interrupt: bool = False
) -> None:
error_plural = dbt.utils.pluralize(num_errors, 'error')
warn_plural = dbt.utils.pluralize(num_warnings, 'warning')
if early_exit:
if keyboard_interrupt:
message = yellow('Exited because of keyboard interrupt.')
elif num_errors > 0:
message = red("Completed with {} and {}:".format(
Expand All @@ -367,11 +367,13 @@ def print_end_of_run_summary(
logger.info('{}'.format(message))


def print_run_end_messages(results, early_exit: bool = False) -> None:
def print_run_end_messages(results, keyboard_interrupt: bool = False) -> None:
errors = [r for r in results if r.error is not None or r.fail]
warnings = [r for r in results if r.warn]
with DbtStatusMessage(), InvocationProcessor():
print_end_of_run_summary(len(errors), len(warnings), early_exit)
print_end_of_run_summary(len(errors),
len(warnings),
keyboard_interrupt)

for error in errors:
print_run_result_error(error, is_warning=False)
Expand Down

0 comments on commit 19a4882

Please sign in to comment.