diff --git a/dbt/adapters/default.py b/dbt/adapters/default.py
index dd9e65b9bb5..40db59c6016 100644
--- a/dbt/adapters/default.py
+++ b/dbt/adapters/default.py
@@ -287,6 +287,17 @@ def get_connection(cls, profile, name=None, recache_if_missing=True):
 
         return cls.get_connection(profile, name)
 
+    @classmethod
+    def cancel_open_connections(cls, profile):
+        global connections_in_use
+
+        for name, connection in connections_in_use.items():
+            if name == 'master':
+                continue
+
+            cls.cancel_connection(profile, connection)
+            yield name
+
     @classmethod
     def total_connections_allocated(cls):
         global connections_in_use, connections_available
diff --git a/dbt/adapters/postgres.py b/dbt/adapters/postgres.py
index 01f567141b2..fdc21615f1d 100644
--- a/dbt/adapters/postgres.py
+++ b/dbt/adapters/postgres.py
@@ -158,3 +158,17 @@ def query_for_existing(cls, profile, schema, model_name=None):
         existing = [(name, relation_type) for (name, relation_type) in results]
 
         return dict(existing)
+
+    @classmethod
+    def cancel_connection(cls, profile, connection):
+        connection_name = connection.get('name')
+        pid = connection.get('handle').get_backend_pid()
+
+        sql = "select pg_terminate_backend({})".format(pid)
+
+        logger.debug("Cancelling query '{}' ({})".format(connection_name, pid))
+
+        _, cursor = cls.add_query(profile, sql, 'master')
+        res = cursor.fetchone()
+
+        logger.debug("Cancel query '{}': {}".format(connection_name, res))
diff --git a/dbt/adapters/snowflake.py b/dbt/adapters/snowflake.py
index d3819434714..30e8b7e7525 100644
--- a/dbt/adapters/snowflake.py
+++ b/dbt/adapters/snowflake.py
@@ -182,3 +182,19 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True,
                 profile, individual_query, model_name, auto_begin)
 
         return connection, cursor
+
+    @classmethod
+    def cancel_connection(cls, profile, connection):
+        handle = connection['handle']
+        sid = handle.session_id
+
+        connection_name = connection.get('name')
+
+        sql = 'select system$abort_session({})'.format(sid)
+
+        logger.debug("Cancelling query '{}' ({})".format(connection_name, sid))
+
+        _, cursor = cls.add_query(profile, sql, 'master')
+        res = cursor.fetchone()
+
+        logger.debug("Cancel query '{}': {}".format(connection_name, res))
diff --git a/dbt/main.py b/dbt/main.py
index 3238638461b..7becfde000c 100644
--- a/dbt/main.py
+++ b/dbt/main.py
@@ -29,6 +29,10 @@ def main(args=None):
     try:
         handle(args)
 
+    except KeyboardInterrupt as e:
+        logger.info("ctrl-c")
+        sys.exit(1)
+
     except RuntimeError as e:
         logger.info("Encountered an error:")
         logger.info(str(e))
diff --git a/dbt/runner.py b/dbt/runner.py
index 156e0e5ec31..6c2cbb6f262 100644
--- a/dbt/runner.py
+++ b/dbt/runner.py
@@ -352,7 +352,8 @@ def compile_node(self, node, flat_graph):
         return node
 
     def safe_compile_node(self, data):
-        node, flat_graph, existing, schema_name, node_index, num_nodes = data
+        node = data['node']
+        flat_graph = data['flat_graph']
 
         result = RunModelResult(node)
         profile = self.project.run_environment()
@@ -368,7 +369,12 @@ def safe_compile_node(self, data):
         return result
 
     def safe_execute_node(self, data):
-        node, flat_graph, existing, schema_name, node_index, num_nodes = data
+        node = data['node']
+        flat_graph = data['flat_graph']
+        existing = data['existing']
+        schema_name = data['schema_name']
+        node_index = data['node_index']
+        num_nodes = data['num_nodes']
 
         start_time = time.time()
 
@@ -573,24 +579,46 @@ def get_idx(node):
         else:
             action = self.safe_compile_node
 
-        for result in pool.imap_unordered(
-                action,
-                [(node, flat_graph, existing, schema_name,
-                  get_idx(node), num_nodes,)
-                 for node in nodes_to_execute]):
-
-            node_results.append(result)
-
-            # propagate so that CTEs get injected properly
-            flat_graph['nodes'][result.node.get('unique_id')] = result.node
-
-            index = get_idx(result.node)
-            if should_execute:
-                track_model_run(index, num_nodes, result)
-
-            if result.errored:
-                on_failure(result.node)
-                logger.info(result.error)
+        node_result = []
+        try:
+            args_list = []
+            for node in nodes_to_execute:
+                args_list.append({
+                    'node': node,
+                    'flat_graph': flat_graph,
+                    'existing': existing,
+                    'schema_name': schema_name,
+                    'node_index': get_idx(node),
+                    'num_nodes': num_nodes
+                })
+
+            for result in pool.imap_unordered(action, args_list):
+                node_results.append(result)
+
+                # propagate so that CTEs get injected properly
+                node_id = result.node.get('unique_id')
+                flat_graph['nodes'][node_id] = result.node
+
+                index = get_idx(result.node)
+                if should_execute:
+                    track_model_run(index, num_nodes, result)
+
+                if result.errored:
+                    on_failure(result.node)
+                    logger.info(result.error)
+
+        except KeyboardInterrupt:
+            pool.close()
+            pool.terminate()
+
+            profile = self.project.run_environment()
+            adapter = get_adapter(profile)
+
+            for conn_name in adapter.cancel_open_connections(profile):
+                dbt.ui.printer.print_cancel_line(conn_name, schema_name)
+
+            pool.join()
+            raise
 
         pool.close()
         pool.join()
diff --git a/dbt/task/run.py b/dbt/task/run.py
index 3b43fab7d37..3b85c3afa22 100644
--- a/dbt/task/run.py
+++ b/dbt/task/run.py
@@ -17,4 +17,5 @@ def run(self):
 
         results = runner.run_models(self.args.models, self.args.exclude)
 
-        logger.info(dbt.ui.printer.get_run_status_line(results))
+        if results:
+            logger.info(dbt.ui.printer.get_run_status_line(results))
diff --git a/dbt/ui/printer.py b/dbt/ui/printer.py
index af088a7d881..456a4dd9455 100644
--- a/dbt/ui/printer.py
+++ b/dbt/ui/printer.py
@@ -46,10 +46,13 @@ def print_timestamped_line(msg):
 
 
 def print_fancy_output_line(msg, status, index, total, execution_time=None):
-    prefix = "{timestamp} | {index} of {total} {message}".format(
+    if index is None or total is None:
+        progress = ''
+    else:
+        progress = '{} of {} '.format(index, total)
+    prefix = "{timestamp} | {progress}{message}".format(
         timestamp=get_timestamp(),
-        index=index,
-        total=total,
+        progress=progress,
         message=msg)
 
     justified = prefix.ljust(80, ".")
@@ -73,6 +76,11 @@ def print_skip_line(model, schema, relation, index, num_models):
     print_fancy_output_line(msg, yellow('SKIP'), index, num_models)
 
 
+def print_cancel_line(model, schema):
+    msg = 'CANCEL query {}.{}'.format(schema, model)
+    print_fancy_output_line(msg, red('CANCEL'), index=None, total=None)
+
+
 def get_counts(flat_nodes):
     counts = {}
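
For reviewers, a minimal sketch of how the pieces in this patch fit together: the runner traps KeyboardInterrupt, terminates the worker pool, asks the adapter to cancel every open connection except 'master', prints a CANCEL line per connection, and re-raises so main.py can log "ctrl-c" and exit non-zero. Everything below is illustrative only; FakeAdapter, run_node, and run_all are hypothetical stand-ins, not dbt APIs.

# sketch_cancel_flow.py -- illustrative only; mirrors the cancel-on-interrupt
# flow introduced by this patch with stand-in names, not dbt's real classes.
from multiprocessing.dummy import Pool  # thread-backed Pool, as a stand-in


class FakeAdapter(object):
    connections_in_use = {
        'master': {'name': 'master'},
        'model_a': {'name': 'model_a'},
        'model_b': {'name': 'model_b'},
    }

    @classmethod
    def cancel_connection(cls, profile, connection):
        # a real adapter runs pg_terminate_backend() / system$abort_session()
        print("terminating backend for '{}'".format(connection['name']))

    @classmethod
    def cancel_open_connections(cls, profile):
        # same shape as the generator added to dbt/adapters/default.py:
        # skip 'master', cancel everything else, yield each cancelled name
        for name, connection in cls.connections_in_use.items():
            if name == 'master':
                continue
            cls.cancel_connection(profile, connection)
            yield name


def run_node(args):
    # stand-in for safe_execute_node: workers now receive a dict, not a tuple
    return 'ran {} ({} of {})'.format(
        args['node'], args['node_index'], args['num_nodes'])


def run_all(nodes):
    pool = Pool(2)
    try:
        args_list = [{'node': node, 'node_index': i + 1, 'num_nodes': len(nodes)}
                     for i, node in enumerate(nodes)]
        for result in pool.imap_unordered(run_node, args_list):
            print(result)
    except KeyboardInterrupt:
        # stop handing out work, cancel in-flight queries, then re-raise
        pool.close()
        pool.terminate()
        for conn_name in FakeAdapter.cancel_open_connections(profile=None):
            print("CANCEL query {}".format(conn_name))
        pool.join()
        raise
    pool.close()
    pool.join()


if __name__ == '__main__':
    run_all(['model_a', 'model_b'])

Under a ctrl-c during the loop, run_all would print one CANCEL line per non-master connection and re-raise, which corresponds to the new print_cancel_line() output and the logger.info("ctrl-c") / sys.exit(1) path added in dbt/main.py.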