Skip to content

Commit

Permalink
Cancel running queries on ctrl-c (#444)
Browse files Browse the repository at this point in the history
* wip

* better ctrl-c handling

* working for snowflake

* better cancel output

* print correct model name for cancel lines

* simlify (kind of) keyboardinterrupt logic

* code cleanup

* pep8

* remove debug code
  • Loading branch information
drewbanin authored May 24, 2017
1 parent dfb24fd commit a225a56
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 24 deletions.
11 changes: 11 additions & 0 deletions dbt/adapters/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,17 @@ def get_connection(cls, profile, name=None, recache_if_missing=True):

return cls.get_connection(profile, name)

@classmethod
def cancel_open_connections(cls, profile):
global connections_in_use

for name, connection in connections_in_use.items():
if name == 'master':
continue

cls.cancel_connection(profile, connection)
yield name

@classmethod
def total_connections_allocated(cls):
global connections_in_use, connections_available
Expand Down
14 changes: 14 additions & 0 deletions dbt/adapters/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,17 @@ def query_for_existing(cls, profile, schema, model_name=None):
existing = [(name, relation_type) for (name, relation_type) in results]

return dict(existing)

@classmethod
def cancel_connection(cls, profile, connection):
connection_name = connection.get('name')
pid = connection.get('handle').get_backend_pid()

sql = "select pg_terminate_backend({})".format(pid)

logger.debug("Cancelling query '{}' ({})".format(connection_name, pid))

_, cursor = cls.add_query(profile, sql, 'master')
res = cursor.fetchone()

logger.debug("Cancel query '{}': {}".format(connection_name, res))
16 changes: 16 additions & 0 deletions dbt/adapters/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,19 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True,
profile, individual_query, model_name, auto_begin)

return connection, cursor

@classmethod
def cancel_connection(cls, profile, connection):
handle = connection['handle']
sid = handle.session_id

connection_name = connection.get('name')

sql = 'select system$abort_session({})'.format(sid)

logger.debug("Cancelling query '{}' ({})".format(connection_name, sid))

_, cursor = cls.add_query(profile, sql, 'master')
res = cursor.fetchone()

logger.debug("Cancel query '{}': {}".format(connection_name, res))
4 changes: 4 additions & 0 deletions dbt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def main(args=None):
try:
handle(args)

except KeyboardInterrupt as e:
logger.info("ctrl-c")
sys.exit(1)

except RuntimeError as e:
logger.info("Encountered an error:")
logger.info(str(e))
Expand Down
68 changes: 48 additions & 20 deletions dbt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ def compile_node(self, node, flat_graph):
return node

def safe_compile_node(self, data):
node, flat_graph, existing, schema_name, node_index, num_nodes = data
node = data['node']
flat_graph = data['flat_graph']

result = RunModelResult(node)
profile = self.project.run_environment()
Expand All @@ -368,7 +369,12 @@ def safe_compile_node(self, data):
return result

def safe_execute_node(self, data):
node, flat_graph, existing, schema_name, node_index, num_nodes = data
node = data['node']
flat_graph = data['flat_graph']
existing = data['existing']
schema_name = data['schema_name']
node_index = data['node_index']
num_nodes = data['num_nodes']

start_time = time.time()

Expand Down Expand Up @@ -573,24 +579,46 @@ def get_idx(node):
else:
action = self.safe_compile_node

for result in pool.imap_unordered(
action,
[(node, flat_graph, existing, schema_name,
get_idx(node), num_nodes,)
for node in nodes_to_execute]):

node_results.append(result)

# propagate so that CTEs get injected properly
flat_graph['nodes'][result.node.get('unique_id')] = result.node

index = get_idx(result.node)
if should_execute:
track_model_run(index, num_nodes, result)

if result.errored:
on_failure(result.node)
logger.info(result.error)
node_result = []
try:
args_list = []
for node in nodes_to_execute:
args_list.append({
'node': node,
'flat_graph': flat_graph,
'existing': existing,
'schema_name': schema_name,
'node_index': get_idx(node),
'num_nodes': num_nodes
})

for result in pool.imap_unordered(action, args_list):
node_results.append(result)

# propagate so that CTEs get injected properly
node_id = result.node.get('unique_id')
flat_graph['nodes'][node_id] = result.node

index = get_idx(result.node)
if should_execute:
track_model_run(index, num_nodes, result)

if result.errored:
on_failure(result.node)
logger.info(result.error)

except KeyboardInterrupt:
pool.close()
pool.terminate()

profile = self.project.run_environment()
adapter = get_adapter(profile)

for conn_name in adapter.cancel_open_connections(profile):
dbt.ui.printer.print_cancel_line(conn_name, schema_name)

pool.join()
raise

pool.close()
pool.join()
Expand Down
3 changes: 2 additions & 1 deletion dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ def run(self):

results = runner.run_models(self.args.models, self.args.exclude)

logger.info(dbt.ui.printer.get_run_status_line(results))
if results:
logger.info(dbt.ui.printer.get_run_status_line(results))
14 changes: 11 additions & 3 deletions dbt/ui/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,13 @@ def print_timestamped_line(msg):


def print_fancy_output_line(msg, status, index, total, execution_time=None):
prefix = "{timestamp} | {index} of {total} {message}".format(
if index is None or total is None:
progress = ''
else:
progress = '{} of {} '.format(index, total)
prefix = "{timestamp} | {progress}{message}".format(
timestamp=get_timestamp(),
index=index,
total=total,
progress=progress,
message=msg)

justified = prefix.ljust(80, ".")
Expand All @@ -73,6 +76,11 @@ def print_skip_line(model, schema, relation, index, num_models):
print_fancy_output_line(msg, yellow('SKIP'), index, num_models)


def print_cancel_line(model, schema):
msg = 'CANCEL query {}.{}'.format(schema, model)
print_fancy_output_line(msg, red('CANCEL'), index=None, total=None)


def get_counts(flat_nodes):
counts = {}

Expand Down

0 comments on commit a225a56

Please sign in to comment.