From fc2b86df4fdd63dbe4d0aa842831b3257a85680a Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 26 Mar 2019 10:36:58 -0600 Subject: [PATCH 1/9] rearrange some path handling in the rpc server task --- core/dbt/task/rpc_server.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/core/dbt/task/rpc_server.py b/core/dbt/task/rpc_server.py index d7a31d374cb..084a3424371 100644 --- a/core/dbt/task/rpc_server.py +++ b/core/dbt/task/rpc_server.py @@ -4,8 +4,10 @@ from jsonrpc import Dispatcher, JSONRPCResponseManager +from werkzeug.wsgi import DispatcherMiddleware from werkzeug.wrappers import Request, Response from werkzeug.serving import run_simple +from werkzeug.exceptions import NotFound from dbt.logger import RPC_LOGGER as logger, add_queue_handler from dbt.task.base import ConfiguredTask @@ -161,14 +163,18 @@ def run(self): ) logger.info( - 'Send requests to http://{}:{}'.format(display_host, port) + 'Send requests to http://{}:{}/jsonrpc'.format(display_host, port) ) - run_simple(host, port, self.handle_request, - processes=self.config.threads) + app = self.handle_request + app = DispatcherMiddleware(app, { + '/jsonrpc': self.handle_jsonrpc_request, + }) + + run_simple(host, port, app, processes=self.config.threads) @Request.application - def handle_request(self, request): + def handle_jsonrpc_request(self, request): msg = 'Received request ({0}) from {0.remote_addr}, data={0.data}' logger.info(msg.format(request)) response = JSONRPCResponseManager.handle(request.data, self.dispatcher) @@ -182,3 +188,7 @@ def handle_request(self, request): response, request.remote_addr, json.loads(json_data)) ) return response + + @Request.application + def handle_request(self, request): + raise NotFound() From f2a0d36b34869c8383d1d45e3790da145fabd300 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 27 Mar 2019 12:44:04 -0600 Subject: [PATCH 2/9] when a dbt RuntimeException is raised inside the exception handler, re-raise it instead of wrapping it --- plugins/bigquery/dbt/adapters/bigquery/connections.py | 5 +++++ plugins/postgres/dbt/adapters/postgres/connections.py | 6 ++++++ plugins/snowflake/dbt/adapters/snowflake/connections.py | 5 +++++ 3 files changed, 16 insertions(+) diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py index 84ed923d257..5f8a125ce2f 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/connections.py +++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py @@ -94,6 +94,11 @@ def exception_handler(self, sql): except Exception as e: logger.debug("Unhandled error while running:\n{}".format(sql)) logger.debug(e) + if isinstance(e, dbt.exceptions.RuntimeException): + # during a sql query, an internal to dbt exception was raised. + # this sounds a lot like a signal handler and probably has + # useful information, so raise it without modification. 
+ raise raise dbt.exceptions.RuntimeException(dbt.compat.to_string(e)) def cancel_open(self): diff --git a/plugins/postgres/dbt/adapters/postgres/connections.py b/plugins/postgres/dbt/adapters/postgres/connections.py index 6ba185ada92..374b86e2604 100644 --- a/plugins/postgres/dbt/adapters/postgres/connections.py +++ b/plugins/postgres/dbt/adapters/postgres/connections.py @@ -82,6 +82,12 @@ def exception_handler(self, sql): logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") self.release() + if isinstance(e, dbt.exceptions.RuntimeException): + # during a sql query, an internal to dbt exception was raised. + # this sounds a lot like a signal handler and probably has + # useful information, so raise it without modification. + raise + raise dbt.exceptions.RuntimeException(e) @classmethod diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index a6b9a67e2a7..b20c6f74fb1 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -97,6 +97,11 @@ def exception_handler(self, sql): logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") self.release() + if isinstance(e, dbt.exceptions.RuntimeException): + # during a sql query, an internal to dbt exception was raised. + # this sounds a lot like a signal handler and probably has + # useful information, so raise it without modification. + raise raise dbt.exceptions.RuntimeException(e.msg) @classmethod From ec1f4bc33d48d78263572da769a0839beca940f8 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 27 Mar 2019 14:03:59 -0600 Subject: [PATCH 3/9] fix bad add_query calls also fix unit tests --- plugins/postgres/dbt/adapters/postgres/connections.py | 2 +- plugins/snowflake/dbt/adapters/snowflake/connections.py | 2 +- test/unit/test_postgres_adapter.py | 2 +- test/unit/test_redshift_adapter.py | 2 +- test/unit/test_snowflake_adapter.py | 3 +-- 5 files changed, 5 insertions(+), 6 deletions(-) diff --git a/plugins/postgres/dbt/adapters/postgres/connections.py b/plugins/postgres/dbt/adapters/postgres/connections.py index 374b86e2604..360a130a936 100644 --- a/plugins/postgres/dbt/adapters/postgres/connections.py +++ b/plugins/postgres/dbt/adapters/postgres/connections.py @@ -137,7 +137,7 @@ def cancel(self, connection): logger.debug("Cancelling query '{}' ({})".format(connection_name, pid)) - _, cursor = self.add_query(sql, 'master') + _, cursor = self.add_query(sql) res = cursor.fetchone() logger.debug("Cancel query '{}': {}".format(connection_name, res)) diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index b20c6f74fb1..e29a404d7e2 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -156,7 +156,7 @@ def cancel(self, connection): logger.debug("Cancelling query '{}' ({})".format(connection_name, sid)) - _, cursor = self.add_query(sql, 'master') + _, cursor = self.add_query(sql) res = cursor.fetchone() logger.debug("Cancel query '{}': {}".format(connection_name, res)) diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index 745f101c46c..1998ec59a71 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -92,7 +92,7 @@ def test_cancel_open_connections_single(self): 
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) - add_query.assert_called_once_with('select pg_terminate_backend(42)', 'master') + add_query.assert_called_once_with('select pg_terminate_backend(42)') master.handle.get_backend_pid.assert_not_called() diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index 8d6a5184751..63d9dec822b 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -126,7 +126,7 @@ def test_cancel_open_connections_single(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) - add_query.assert_called_once_with('select pg_terminate_backend(42)', 'master') + add_query.assert_called_once_with('select pg_terminate_backend(42)') master.handle.get_backend_pid.assert_not_called() diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 1e55e09e7b0..caba79ea2e2 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -155,8 +155,7 @@ def test_cancel_open_connections_single(self): self.assertEqual( len(list(self.adapter.cancel_open_connections())), 1) - add_query.assert_called_once_with( - 'select system$abort_session(42)', 'master') + add_query.assert_called_once_with('select system$abort_session(42)') def test_client_session_keep_alive_false_by_default(self): self.adapter.connections.set_connection_name(name='new_connection_with_new_config') From 3f9b9962c3d7682844eb4eac1e04a6b61a5a5af3 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 26 Mar 2019 12:41:45 -0600 Subject: [PATCH 4/9] add "ps" and "kill" commands, and track tasks in flight proper cancel support Refactor rpc server logic a bit fix an issue in query cancelation where we would cancel ourselves fix exception handling misbehavior --- core/dbt/adapters/sql/connections.py | 3 +- core/dbt/exceptions.py | 16 ++ core/dbt/rpc.py | 324 ++++++++++++++++++++++++++- core/dbt/task/compile.py | 46 +++- core/dbt/task/rpc_server.py | 145 ++---------- core/dbt/task/runnable.py | 1 - 6 files changed, 400 insertions(+), 135 deletions(-) diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py index 13bd312e876..a6db10d1215 100644 --- a/core/dbt/adapters/sql/connections.py +++ b/core/dbt/adapters/sql/connections.py @@ -30,9 +30,10 @@ def cancel(self, connection): def cancel_open(self): names = [] + this_connection = self.get_if_exists() with self.lock: for connection in self.thread_connections.values(): - if connection.name == 'master': + if connection is this_connection: continue self.cancel(connection) diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 22462bd5827..c96e29e6fd8 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -129,6 +129,22 @@ def data(self): return result +class RPCKilledException(RuntimeException): + CODE = 10009 + MESSAGE = 'RPC process killed' + + def __init__(self, signum): + self.signum = signum + self.message = 'RPC process killed by signal {}'.format(self.signum) + super(RPCKilledException, self).__init__(self.message) + + def data(self): + return { + 'signum': self.signum, + 'message': self.message, + } + + class DatabaseException(RuntimeException): CODE = 10003 MESSAGE = "Database Error" diff --git a/core/dbt/rpc.py b/core/dbt/rpc.py index f20f2ab32fb..979b7ecb76c 100644 --- a/core/dbt/rpc.py +++ b/core/dbt/rpc.py @@ -1,5 +1,23 @@ -from jsonrpc.exceptions import JSONRPCDispatchException, JSONRPCInvalidParams +from jsonrpc.exceptions import 
JSONRPCDispatchException, \ + JSONRPCInvalidParams, \ + JSONRPCParseError, \ + JSONRPCInvalidRequestException, \ + JSONRPCInvalidRequest +from jsonrpc import JSONRPCResponseManager +from jsonrpc.jsonrpc import JSONRPCRequest +from jsonrpc.jsonrpc2 import JSONRPC20Response +import json +import uuid +import multiprocessing +import os +import signal +import time +from collections import namedtuple + +from dbt.logger import RPC_LOGGER as logger +from dbt.logger import add_queue_handler +from dbt.compat import QueueEmpty import dbt.exceptions @@ -16,6 +34,12 @@ def __init__(self, code=None, message=None, data=None, logs=None): data=data) self.logs = logs + def __str__(self): + return ( + 'RPCException({0.code}, {0.message}, {0.data}, {1.logs})' + .format(self.error, self) + ) + @property def logs(self): return self.error.data.get('logs') @@ -66,3 +90,301 @@ def terminating(cls): cls.Error, cls.Result ] + + +def sigterm_handler(signum, frame): + raise dbt.exceptions.RPCKilledException(signum) + + +class RequestDispatcher(object): + """A special dispatcher that knows about requests.""" + def __init__(self, http_request, json_rpc_request, manager): + self.http_request = http_request + self.json_rpc_request = json_rpc_request + self.manager = manager + self.task = None + + def rpc_factory(self, task): + request_handler = RequestTaskHandler(task, + self.http_request, + self.json_rpc_request) + + def rpc_func(**kwargs): + try: + self.manager.add_request(request_handler) + return request_handler.handle(kwargs) + finally: + self.manager.mark_done(request_handler) + + return rpc_func + + def __getitem__(self, key): + # the dispatcher's keys are method names and its values are functions + # that implement the RPC calls + func = self.manager.rpc_builtin(key) + if func is not None: + return func + + task = self.manager.rpc_task(key) + return self.rpc_factory(task) + + +class RequestTaskHandler(object): + def __init__(self, task, http_request, json_rpc_request): + self.task = task + self.http_request = http_request + self.json_rpc_request = json_rpc_request + self.queue = None + self.process = None + self.started = None + self.timeout = None + self.logs = [] + self.task_id = uuid.uuid4() + + @property + def request_source(self): + return self.http_request.remote_addr + + @property + def request_id(self): + return self.json_rpc_request._id + + @property + def method(self): + return self.task.METHOD_NAME + + def _next_timeout(self): + if self.timeout is None: + return None + end = self.started + self.timeout + timeout = end - time.time() + if timeout < 0: + raise dbt.exceptions.RPCTimeoutException(self.timeout) + return timeout + + def _wait_for_results(self): + """Wait for results off the queue. If there is a timeout set, and it is + exceeded, raise an RPCTimeoutException. 
+ """ + while True: + get_timeout = self._next_timeout() + try: + msgtype, value = self.queue.get(timeout=get_timeout) + except QueueEmpty: + raise dbt.exceptions.RPCTimeoutException(self.timeout) + + if msgtype == QueueMessageType.Log: + self.logs.append(value) + elif msgtype in QueueMessageType.terminating(): + return msgtype, value + else: + raise dbt.exceptions.InternalException( + 'Got invalid queue message type {}'.format(msgtype) + ) + + def _join_process(self): + try: + msgtype, result = self._wait_for_results() + except dbt.exceptions.RPCTimeoutException as exc: + self.process.terminate() + raise timeout_error(self.timeout) + except dbt.exceptions.Exception as exc: + raise dbt_error(exc) + except Exception as exc: + raise server_error(exc) + finally: + self.process.join() + + if msgtype == QueueMessageType.Error: + raise RPCException.from_error(result) + + return result + + def get_result(self): + try: + result = self._join_process() + except RPCException as exc: + exc.logs = self.logs + raise + + result['logs'] = self.logs + return result + + def task_bootstrap(self, kwargs): + signal.signal(signal.SIGTERM, sigterm_handler) + # the first thing we do in a new process: start logging + add_queue_handler(self.queue) + + error = None + result = None + try: + result = self.task.handle_request(**kwargs) + except RPCException as exc: + error = exc + except dbt.exceptions.Exception as exc: + logger.debug('dbt runtime exception', exc_info=True) + error = dbt_error(exc) + except Exception as exc: + logger.debug('uncaught python exception', exc_info=True) + error = server_error(exc) + + # put whatever result we got onto the queue as well. + if error is not None: + self.queue.put([QueueMessageType.Error, error.error]) + else: + self.queue.put([QueueMessageType.Result, result]) + + def handle(self, kwargs): + self.started = time.time() + self.timeout = kwargs.pop('timeout', None) + self.queue = multiprocessing.Queue() + self.process = multiprocessing.Process( + target=self.task_bootstrap, + args=(kwargs,) + ) + self.process.start() + return self.get_result() + + @property + def state(self): + if self.started is None: + return 'not started' + elif self.process is None: + return 'initializing' + elif self.process.is_alive(): + return 'running' + else: + return 'finished' + + +TaskRow = namedtuple( + 'TaskRow', + 'task_id request_id request_source method state start elapsed timeout' +) + + +class TaskManager(object): + def __init__(self): + self.tasks = {} + self.completed = {} + self._rpc_task_map = {} + self._rpc_function_map = {} + self._lock = multiprocessing.Lock() + + def add_request(self, request_handler): + self.tasks[request_handler.task_id] = request_handler + + def add_task_handler(self, task): + self._rpc_task_map[task.METHOD_NAME] = task + + def rpc_task(self, method_name): + return self._rpc_task_map[method_name] + + def process_listing(self, active=True, completed=False): + included_tasks = {} + with self._lock: + if completed: + included_tasks.update(self.completed) + if active: + included_tasks.update(self.tasks) + + table = [] + now = time.time() + for task_handler in included_tasks.values(): + start = task_handler.started + if start is not None: + elapsed = now - start + + table.append(TaskRow( + str(task_handler.task_id), task_handler.request_id, + task_handler.request_source, task_handler.method, + task_handler.state, start, elapsed, task_handler.timeout + )) + table.sort(key=lambda r: (r.state, r.start)) + result = { + 'columns': list(TaskRow._fields), + 'rows': [list(r) for 
r in table], + } + return result + + def process_kill(self, task_id): + # TODO: this result design is terrible + result = { + 'found': False, + 'started': False, + 'finished': False, + 'killed': False + } + task_id = uuid.UUID(task_id) + try: + task = self.tasks[task_id] + except KeyError: + # nothing to do! + return result + + result['found'] = True + + if task.process is None: + return result + pid = task.process.pid + if pid is None: + return result + + result['started'] = True + + if task.process.is_alive(): + os.kill(pid, signal.SIGINT) + result['killed'] = True + return result + + result['finished'] = True + return result + + def rpc_builtin(self, method_name): + if method_name == 'ps': + return self.process_listing + if method_name == 'kill': + return self.process_kill + return None + + def mark_done(self, request_handler): + task_id = request_handler.task_id + with self._lock: + if task_id not in self.tasks: + # lost a task! Maybe it was killed before it started. + return + self.completed[task_id] = self.tasks.pop(task_id) + + def methods(self): + return list(self._rpc_task_map) + + +class ResponseManager(JSONRPCResponseManager): + """Override the default response manager to handle request metadata and + track in-flight tasks. + """ + @classmethod + def handle(cls, http_request, task_manager): + # pretty much just copy+pasted from the original, with slight tweaks to + # preserve the request + request_str = http_request.data + if isinstance(request_str, bytes): + request_str = request_str.decode("utf-8") + + try: + data = json.loads(request_str) + except (TypeError, ValueError): + return JSONRPC20Response(error=JSONRPCParseError()._data) + + try: + request = JSONRPCRequest.from_data(data) + except JSONRPCInvalidRequestException: + return JSONRPC20Response(error=JSONRPCInvalidRequest()._data) + + dispatcher = RequestDispatcher( + http_request, + request, + task_manager + ) + + return cls.handle_request(request, dispatcher) diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index 8b177812bb5..936844549a6 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -1,5 +1,8 @@ import os +import signal +import threading +from dbt.adapters.factory import get_adapter from dbt.clients.jinja import extract_toplevel_blocks from dbt.compilation import compile_manifest from dbt.loader import load_all_projects @@ -9,6 +12,7 @@ from dbt.parser.macros import MacroParser from dbt.parser.util import ParserUtils import dbt.ui.printer +from dbt.logger import RPC_LOGGER as rpc_logger from dbt.task.runnable import GraphRunnableTask, RemoteCallable @@ -111,13 +115,47 @@ def _get_exec_node(self, name, sql, macros): self.linker = compile_manifest(self.config, self.manifest, write=False) return node + def _raise_set_error(self): + if self._raise_next_tick is not None: + raise self._raise_next_tick + + def _in_thread(self, node, thread_done): + runner = self.get_runner(node) + try: + self.node_results.append(runner.safe_run(self.manifest)) + except Exception as exc: + self._raise_next_tick = exc + finally: + thread_done.set() + def handle_request(self, name, sql, macros=None): node = self._get_exec_node(name, sql, macros) selected_uids = [node.unique_id] self.runtime_cleanup(selected_uids) - self.job_queue = self.linker.as_graph_queue(self.manifest, - selected_uids) - result = self.get_runner(node).safe_run(self.manifest) - return result.serialize() + thread_done = threading.Event() + thread = threading.Thread(target=self._in_thread, + args=(node, thread_done)) + thread.start() + 
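+        # a `kill` RPC sends SIGINT to this worker process; it surfaces here
+        # as KeyboardInterrupt in the main thread while the node itself runs
+        # in the background thread started above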
try: + thread_done.wait() + except KeyboardInterrupt: + adapter = get_adapter(self.config) + if adapter.is_cancelable(): + + for conn_name in adapter.cancel_open_connections(): + rpc_logger.debug('canceled query {}'.format(conn_name)) + + thread.join() + else: + msg = ("The {} adapter does not support query " + "cancellation. Some queries may still be " + "running!".format(adapter.type())) + + rpc_logger.debug(msg) + + raise dbt.exceptions.RPCKilledException(signal.SIGINT) + + self._raise_set_error() + return self.node_results[0].serialize() diff --git a/core/dbt/task/rpc_server.py b/core/dbt/task/rpc_server.py index 084a3424371..0dfdfa4c277 100644 --- a/core/dbt/task/rpc_server.py +++ b/core/dbt/task/rpc_server.py @@ -1,149 +1,33 @@ import json -import multiprocessing -import time - -from jsonrpc import Dispatcher, JSONRPCResponseManager from werkzeug.wsgi import DispatcherMiddleware from werkzeug.wrappers import Request, Response from werkzeug.serving import run_simple from werkzeug.exceptions import NotFound -from dbt.logger import RPC_LOGGER as logger, add_queue_handler +from dbt.logger import RPC_LOGGER as logger from dbt.task.base import ConfiguredTask from dbt.task.compile import CompileTask, RemoteCompileTask from dbt.task.run import RemoteRunTask from dbt.utils import JSONEncoder -from dbt.compat import QueueEmpty -import dbt.exceptions from dbt import rpc -class RequestTaskHandler(object): - def __init__(self, task): - self.task = task - self.queue = None - self.process = None - self.started = None - self.timeout = None - self.logs = [] - - def _next_timeout(self): - if self.timeout is None: - return None - end = self.started + self.timeout - timeout = end - time.time() - if timeout < 0: - raise dbt.exceptions.RPCTimeoutException(self.timeout) - return timeout - - def _wait_for_results(self): - """Wait for results off the queue. If there is a timeout set, and it is - exceeded, raise an RPCTimeoutException. 
- """ - while True: - get_timeout = self._next_timeout() - try: - msgtype, value = self.queue.get(timeout=get_timeout) - except QueueEmpty: - raise dbt.exceptions.RPCTimeoutException(self.timeout) - - if msgtype == rpc.QueueMessageType.Log: - self.logs.append(value) - elif msgtype in rpc.QueueMessageType.terminating(): - return msgtype, value - else: - raise dbt.exceptions.InternalException( - 'Got invalid queue message type {}'.format(msgtype) - ) - - def _join_process(self): - try: - msgtype, result = self._wait_for_results() - except dbt.exceptions.RPCTimeoutException as exc: - self.process.terminate() - raise rpc.timeout_error(self.timeout) - except dbt.exceptions.Exception as exc: - raise rpc.dbt_error(exc) - except Exception as exc: - raise rpc.server_error(exc) - finally: - self.process.join() - - if msgtype == rpc.QueueMessageType.Error: - raise rpc.RPCException.from_error(result) - - return result - - def get_result(self): - try: - result = self._join_process() - except rpc.RPCException as exc: - exc.logs = self.logs - raise - - result['logs'] = self.logs - return result - - def task_bootstrap(self, kwargs): - # the first thing we do in a new process: start logging - add_queue_handler(self.queue) - - error = None - result = None - try: - result = self.task.handle_request(**kwargs) - except rpc.RPCException as exc: - error = exc - except dbt.exceptions.Exception as exc: - logger.debug('dbt runtime exception', exc_info=True) - error = rpc.dbt_error(exc) - except Exception as exc: - logger.debug('uncaught python exception', exc_info=True) - error = rpc.server_error(exc) - - # put whatever result we got onto the queue as well. - if error is not None: - self.queue.put([rpc.QueueMessageType.Error, error.error]) - else: - self.queue.put([rpc.QueueMessageType.Result, result]) - - def handle(self, kwargs): - self.started = time.time() - self.timeout = kwargs.pop('timeout', None) - self.queue = multiprocessing.Queue() - self.process = multiprocessing.Process( - target=self.task_bootstrap, - args=(kwargs,) - ) - self.process.start() - return self.get_result() - - @classmethod - def factory(cls, task): - def handler(**kwargs): - return cls(task).handle(kwargs) - return handler - - class RPCServerTask(ConfiguredTask): def __init__(self, args, config, tasks=None): super(RPCServerTask, self).__init__(args, config) # compile locally - self.compile_task = CompileTask(args, config) - self.compile_task.run() - self.dispatcher = Dispatcher() + self.manifest = self._compile_manifest() + self.task_manager = rpc.TaskManager() tasks = tasks or [RemoteCompileTask, RemoteRunTask] for cls in tasks: - self.register(cls(args, config, self.manifest)) - - def register(self, task): - self.dispatcher.add_method(RequestTaskHandler.factory(task), - name=task.METHOD_NAME) + task = cls(args, config, self.manifest) + self.task_manager.add_task_handler(task) - @property - def manifest(self): - return self.compile_task.manifest + def _compile_manifest(self): + compile_task = CompileTask(self.args, self.config) + compile_task.run() + return compile_task.manifest def run(self): host = self.args.host @@ -159,7 +43,7 @@ def run(self): ) logger.info( - 'Supported methods: {}'.format(list(self.dispatcher.keys())) + 'Supported methods: {}'.format(self.task_manager.methods()) ) logger.info( @@ -171,13 +55,18 @@ def run(self): '/jsonrpc': self.handle_jsonrpc_request, }) - run_simple(host, port, app, processes=self.config.threads) + # we have to run in threaded mode if we want to share subprocess + # handles, which is the easiest 
way to implement `kill` (it makes `ps` + # easier as well). The alternative involves tracking metadata+state in + # a multiprocessing.Manager, adds polling the manager to the request + # task handler and in general gets messy fast. + run_simple(host, port, app, threaded=True) @Request.application def handle_jsonrpc_request(self, request): msg = 'Received request ({0}) from {0.remote_addr}, data={0.data}' logger.info(msg.format(request)) - response = JSONRPCResponseManager.handle(request.data, self.dispatcher) + response = rpc.ResponseManager.handle(request, self.task_manager) json_data = json.dumps(response.data, cls=JSONEncoder) response = Response(json_data, mimetype='application/json') # this looks and feels dumb, but our json encoder converts decimals and diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index d06bf83c35c..e210dc6a5b6 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -4,7 +4,6 @@ import time from abc import abstractmethod from multiprocessing.dummy import Pool as ThreadPool -from jsonrpc.exceptions import JSONRPCInvalidParams from dbt import rpc from dbt.task.base import ConfiguredTask From 8410be848fead790f3c7032fe71e8de6c49a4c85 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 28 Mar 2019 10:26:03 -0600 Subject: [PATCH 5/9] fix the methods list --- core/dbt/rpc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/dbt/rpc.py b/core/dbt/rpc.py index 979b7ecb76c..ee42e624332 100644 --- a/core/dbt/rpc.py +++ b/core/dbt/rpc.py @@ -356,7 +356,8 @@ def mark_done(self, request_handler): self.completed[task_id] = self.tasks.pop(task_id) def methods(self): - return list(self._rpc_task_map) + rpc_builtin_methods = ['ps', 'kill'] + return list(self._rpc_task_map) + rpc_builtin_methods class ResponseManager(JSONRPCResponseManager): From 182714b6b801107d1e1baec0dd49c218c52b1416 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 28 Mar 2019 10:26:23 -0600 Subject: [PATCH 6/9] handle ctrl+c during parsing, etc --- core/dbt/task/compile.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index 936844549a6..64cbefc9daf 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -129,16 +129,18 @@ def _in_thread(self, node, thread_done): thread_done.set() def handle_request(self, name, sql, macros=None): - node = self._get_exec_node(name, sql, macros) + # we could get a ctrl+c at any time, including during parsing. + thread = None + try: + node = self._get_exec_node(name, sql, macros) - selected_uids = [node.unique_id] - self.runtime_cleanup(selected_uids) + selected_uids = [node.unique_id] + self.runtime_cleanup(selected_uids) - thread_done = threading.Event() - thread = threading.Thread(target=self._in_thread, - args=(node, thread_done)) - thread.start() - try: + thread_done = threading.Event() + thread = threading.Thread(target=self._in_thread, + args=(node, thread_done)) + thread.start() thread_done.wait() except KeyboardInterrupt: adapter = get_adapter(self.config) @@ -146,8 +148,8 @@ def handle_request(self, name, sql, macros=None): for conn_name in adapter.cancel_open_connections(): rpc_logger.debug('canceled query {}'.format(conn_name)) - - thread.join() + if thread: + thread.join() else: msg = ("The {} adapter does not support query " "cancellation. 
Some queries may still be " From 6c8e74bac9a6aea959009342b90ad15353e6bc81 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 2 Apr 2019 15:15:37 -0600 Subject: [PATCH 7/9] tests, fight with (test-only?) deadlocks when adding more threads stops helping, add more sleeps --- core/dbt/rpc.py | 4 + .../042_sources_test/test_sources.py | 177 ++++++++++++++++-- 2 files changed, 169 insertions(+), 12 deletions(-) diff --git a/core/dbt/rpc.py b/core/dbt/rpc.py index ee42e624332..2e25b9e5655 100644 --- a/core/dbt/rpc.py +++ b/core/dbt/rpc.py @@ -221,6 +221,10 @@ def task_bootstrap(self, kwargs): result = self.task.handle_request(**kwargs) except RPCException as exc: error = exc + except dbt.exceptions.RPCKilledException as exc: + # do NOT log anything here, you risk triggering a deadlock on the + # queue handler we inserted above + error = dbt_error(exc) except dbt.exceptions.Exception as exc: logger.debug('dbt runtime exception', exc_info=True) error = dbt_error(exc) diff --git a/test/integration/042_sources_test/test_sources.py b/test/integration/042_sources_test/test_sources.py index dd309cfb69e..2f6f37b4a45 100644 --- a/test/integration/042_sources_test/test_sources.py +++ b/test/integration/042_sources_test/test_sources.py @@ -7,6 +7,7 @@ from base64 import standard_b64encode as b64 import requests import socket +import threading import time @@ -278,7 +279,8 @@ def __init__(self, cli_vars=None): handle_and_check_args.extend(['--vars', cli_vars]) super(ServerProcess, self).__init__( target=handle_and_check, - args=(handle_and_check_args,)) + args=(handle_and_check_args,), + name='ServerProcess') def is_up(self): sock = socket.socket() @@ -300,6 +302,36 @@ def start(self): raise Exception('server never appeared!') +def query_url(url, query): + headers = {'content-type': 'application/json'} + return requests.post(url, headers=headers, data=json.dumps(query)) + + +class BackgroundQueryProcess(multiprocessing.Process): + def __init__(self, query, url, group=None, name=None): + parent, child = multiprocessing.Pipe() + self.parent_pipe = parent + self.child_pipe = child + self.query = query + self.url = url + super(BackgroundQueryProcess, self).__init__(group=group, name=name) + + def run(self): + try: + result = query_url(self.url, self.query).json() + except Exception as exc: + self.child_pipe.send(('error', str(exc))) + else: + self.child_pipe.send(('result', result)) + + def wait_result(self): + result_type, result = self.parent_pipe.recv() + self.join() + if result_type == 'error': + raise Exception(result) + else: + return result + _select_from_ephemeral = '''with __dbt__CTE__ephemeral_model as ( @@ -328,7 +360,8 @@ def project_config(self): 'macro-paths': ['test/integration/042_sources_test/macros'], } - def build_query(self, method, kwargs, sql=None, test_request_id=1, macros=None): + def build_query(self, method, kwargs, sql=None, test_request_id=1, + macros=None): body_data = '' if sql is not None: body_data += sql @@ -346,15 +379,34 @@ def build_query(self, method, kwargs, sql=None, test_request_id=1, macros=None): 'id': test_request_id } - def perform_query(self, query): - url = 'http://localhost:{}/jsonrpc'.format(self._server.port) - headers = {'content-type': 'application/json'} - response = requests.post(url, headers=headers, data=json.dumps(query)) - return response + @property + def url(self): + return 'http://localhost:{}/jsonrpc'.format(self._server.port) def query(self, _method, _sql=None, _test_request_id=1, macros=None, **kwargs): built = self.build_query(_method, kwargs, 
_sql, _test_request_id, macros) - return self.perform_query(built) + return query_url(self.url, built) + + def handle_result(self, bg_query, pipe): + result_type, result = pipe.recv() + bg_query.join() + if result_type == 'error': + raise result + else: + return result + + def background_query(self, _method, _sql=None, _test_request_id=1, + _block=False, macros=None, **kwargs): + built = self.build_query(_method, kwargs, _sql, _test_request_id, + macros) + + url = 'http://localhost:{}/jsonrpc'.format(self._server.port) + name = _method + if 'name' in kwargs: + name += ' ' + kwargs['name'] + bg_query = BackgroundQueryProcess(built, url, name=name) + bg_query.start() + return bg_query def assertResultHasTimings(self, result, *names): self.assertIn('timing', result) @@ -375,15 +427,15 @@ def assertIsResult(self, data): self.assertNotIn('error', data) return data['result'] - def assertIsError(self, data): - self.assertEqual(data['id'], 1) + def assertIsError(self, data, id_=1): + self.assertEqual(data['id'], id_) self.assertEqual(data['jsonrpc'], '2.0') self.assertIn('error', data) self.assertNotIn('result', data) return data['error'] - def assertIsErrorWithCode(self, data, code): - error = self.assertIsError(data) + def assertIsErrorWithCode(self, data, code, id_=1): + error = self.assertIsError(data, id_) self.assertIn('code', error) self.assertIn('message', error) self.assertEqual(error['code'], code) @@ -632,6 +684,107 @@ def test_run_postgres(self): table={'column_names': ['id'], 'rows': [[1.0]]} ) + @use_profile('postgres') + def test_ps_kill_postgres(self): + done_query = self.query('compile', 'select 1 as id', name='done').json() + self.assertIsResult(done_query) + pg_sleeper, sleep_task_id, request_id = self._get_sleep_query() + + empty_ps_result = self.query('ps', completed=False, active=False).json() + result = self.assertIsResult(empty_ps_result) + self.assertEqual(len(result['rows']), 0) + + sleeper_ps_result = self.query('ps', completed=False, active=True).json() + result = self.assertIsResult(sleeper_ps_result) + self.assertEqual(len(result['rows']), 1) + self.assertEqual(len(result['rows'][0]), len(result['columns'])) + rowdict = [{k: v for k, v in zip(result['columns'], row)} for row in result['rows']] + self.assertEqual(rowdict[0]['request_id'], request_id) + self.assertEqual(rowdict[0]['method'], 'run') + self.assertEqual(rowdict[0]['state'], 'running') + self.assertEqual(rowdict[0]['timeout'], None) + + complete_ps_result = self.query('ps', completed=True, active=False).json() + result = self.assertIsResult(complete_ps_result) + self.assertEqual(len(result['rows']), 1) + self.assertEqual(len(result['rows'][0]), len(result['columns'])) + rowdict = [{k: v for k, v in zip(result['columns'], row)} for row in result['rows']] + self.assertEqual(rowdict[0]['request_id'], 1) + self.assertEqual(rowdict[0]['method'], 'compile') + self.assertEqual(rowdict[0]['state'], 'finished') + self.assertEqual(rowdict[0]['timeout'], None) + + all_ps_result = self.query('ps', completed=True, active=True).json() + result = self.assertIsResult(all_ps_result) + self.assertEqual(len(result['rows']), 2) + self.assertEqual(len(result['rows'][0]), len(result['columns'])) + self.assertEqual(len(result['rows'][1]), len(result['columns'])) + rowdict = [{k: v for k, v in zip(result['columns'], row)} for row in result['rows']] + rowdict.sort(key=lambda r: r['start']) + self.assertEqual(rowdict[0]['request_id'], 1) + self.assertEqual(rowdict[0]['method'], 'compile') + self.assertEqual(rowdict[0]['state'], 
'finished') + self.assertEqual(rowdict[0]['timeout'], None) + self.assertEqual(rowdict[1]['request_id'], request_id) + self.assertEqual(rowdict[1]['method'], 'run') + self.assertEqual(rowdict[1]['state'], 'running') + self.assertEqual(rowdict[1]['timeout'], None) + + self.kill_and_assert(pg_sleeper, sleep_task_id, request_id) + + def kill_and_assert(self, pg_sleeper, task_id, request_id): + kill_result = self.query('kill', task_id=task_id).json() + kill_time = time.time() + result = self.assertIsResult(kill_result) + self.assertTrue(result['killed']) + + sleeper_result = pg_sleeper.wait_result() + result_time = time.time() + error = self.assertIsErrorWithCode(sleeper_result, 10009, request_id) + self.assertEqual(error['message'], 'RPC process killed') + self.assertIn('data', error) + error_data = error['data'] + self.assertEqual(error_data['signum'], 2) + self.assertEqual(error_data['message'], 'RPC process killed by signal 2') + self.assertIn('logs', error_data) + # it should take less than 5s to kill the process if things are working + # properly + self.assertLess(result_time, kill_time + 5) + return error_data + + def _get_sleep_query(self): + request_id = 90890 + pg_sleeper = self.background_query( + 'run', + 'select pg_sleep(15)', + _test_request_id=request_id, + name='sleeper', + ) + + for _ in range(20): + time.sleep(0.2) + sleeper_ps_result = self.query('ps', completed=False, active=True).json() + result = self.assertIsResult(sleeper_ps_result) + rows = [{k: v for k, v in zip(result['columns'], row)} for row in result['rows']] + for row in rows: + if row['request_id'] == request_id and row['state'] == 'running': + return pg_sleeper, row['task_id'], request_id + + self.assertTrue(False, 'request ID never found running!') + + @use_profile('postgres') + def test_ps_kill_longwait_postgres(self): + pg_sleeper, sleep_task_id, request_id = self._get_sleep_query() + + # the test above frequently kills the process during parsing of the + # requested node. That's also a useful test, but we should test that + # we cancel the in-progress sleep query. 
+ time.sleep(3) + + error_data = self.kill_and_assert(pg_sleeper, sleep_task_id, request_id) + # we should have logs if we did anything + self.assertTrue(len(error_data['logs']) > 0) + @use_profile('postgres') def test_invalid_requests_postgres(self): data = self.query( From 3b357340fdb375e49c03f5b75e5f49af55e6c3f9 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 28 Mar 2019 11:45:43 -0600 Subject: [PATCH 8/9] skip the timing assert on python 2.x --- .../042_sources_test/test_sources.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/integration/042_sources_test/test_sources.py b/test/integration/042_sources_test/test_sources.py index 2f6f37b4a45..0e4c92a00ed 100644 --- a/test/integration/042_sources_test/test_sources.py +++ b/test/integration/042_sources_test/test_sources.py @@ -1,15 +1,14 @@ -import unittest -from datetime import datetime, timedelta import json -import os - import multiprocessing -from base64 import standard_b64encode as b64 -import requests +import os import socket -import threading +import sys import time +import unittest +from base64 import standard_b64encode as b64 +from datetime import datetime, timedelta +import requests from dbt.exceptions import CompilationException from test.integration.base import DBTIntegrationTest, use_profile, AnyFloat, \ @@ -748,8 +747,9 @@ def kill_and_assert(self, pg_sleeper, task_id, request_id): self.assertEqual(error_data['message'], 'RPC process killed by signal 2') self.assertIn('logs', error_data) # it should take less than 5s to kill the process if things are working - # properly - self.assertLess(result_time, kill_time + 5) + # properly. On python 2.x, things do not work properly. + if sys.version_info.major > 2: + self.assertLess(result_time, kill_time + 5) return error_data def _get_sleep_query(self): From 2654c795485e7cc439e9af050a9880c96a634efd Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 3 Apr 2019 10:01:36 -0600 Subject: [PATCH 9/9] PR feedback --- core/dbt/rpc.py | 3 +-- test/integration/042_sources_test/test_sources.py | 12 ++++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/core/dbt/rpc.py b/core/dbt/rpc.py index 2e25b9e5655..de1eae7f826 100644 --- a/core/dbt/rpc.py +++ b/core/dbt/rpc.py @@ -306,8 +306,7 @@ def process_listing(self, active=True, completed=False): )) table.sort(key=lambda r: (r.state, r.start)) result = { - 'columns': list(TaskRow._fields), - 'rows': [list(r) for r in table], + 'rows': [dict(r._asdict()) for r in table], } return result diff --git a/test/integration/042_sources_test/test_sources.py b/test/integration/042_sources_test/test_sources.py index 0e4c92a00ed..ab56d077443 100644 --- a/test/integration/042_sources_test/test_sources.py +++ b/test/integration/042_sources_test/test_sources.py @@ -696,8 +696,7 @@ def test_ps_kill_postgres(self): sleeper_ps_result = self.query('ps', completed=False, active=True).json() result = self.assertIsResult(sleeper_ps_result) self.assertEqual(len(result['rows']), 1) - self.assertEqual(len(result['rows'][0]), len(result['columns'])) - rowdict = [{k: v for k, v in zip(result['columns'], row)} for row in result['rows']] + rowdict = result['rows'] self.assertEqual(rowdict[0]['request_id'], request_id) self.assertEqual(rowdict[0]['method'], 'run') self.assertEqual(rowdict[0]['state'], 'running') @@ -706,8 +705,7 @@ def test_ps_kill_postgres(self): complete_ps_result = self.query('ps', completed=True, active=False).json() result = self.assertIsResult(complete_ps_result) 
self.assertEqual(len(result['rows']), 1) - self.assertEqual(len(result['rows'][0]), len(result['columns'])) - rowdict = [{k: v for k, v in zip(result['columns'], row)} for row in result['rows']] + rowdict = result['rows'] self.assertEqual(rowdict[0]['request_id'], 1) self.assertEqual(rowdict[0]['method'], 'compile') self.assertEqual(rowdict[0]['state'], 'finished') @@ -716,9 +714,7 @@ def test_ps_kill_postgres(self): all_ps_result = self.query('ps', completed=True, active=True).json() result = self.assertIsResult(all_ps_result) self.assertEqual(len(result['rows']), 2) - self.assertEqual(len(result['rows'][0]), len(result['columns'])) - self.assertEqual(len(result['rows'][1]), len(result['columns'])) - rowdict = [{k: v for k, v in zip(result['columns'], row)} for row in result['rows']] + rowdict = result['rows'] rowdict.sort(key=lambda r: r['start']) self.assertEqual(rowdict[0]['request_id'], 1) self.assertEqual(rowdict[0]['method'], 'compile') @@ -765,7 +761,7 @@ def _get_sleep_query(self): time.sleep(0.2) sleeper_ps_result = self.query('ps', completed=False, active=True).json() result = self.assertIsResult(sleeper_ps_result) - rows = [{k: v for k, v in zip(result['columns'], row)} for row in result['rows']] + rows = result['rows'] for row in rows: if row['request_id'] == request_id and row['state'] == 'running': return pg_sleeper, row['task_id'], request_id
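
For reference, a minimal client sketch against the `ps` and `kill` built-ins these patches add. It mirrors the request shape used by the integration tests above (a JSON-RPC 2.0 payload POSTed to `/jsonrpc`); the host and port below are assumptions rather than values fixed by the series, and the snippet is illustrative only.

    import json
    import requests

    # Assumed server location; substitute whatever --host/--port the RPC
    # server was started with.
    URL = 'http://localhost:8580/jsonrpc'
    HEADERS = {'content-type': 'application/json'}

    def rpc(method, request_id=1, **params):
        # JSON-RPC 2.0 envelope, matching build_query() in test_sources.py
        payload = {
            'jsonrpc': '2.0',
            'method': method,
            'params': params,
            'id': request_id,
        }
        return requests.post(URL, headers=HEADERS,
                             data=json.dumps(payload)).json()

    # 'ps' lists tasks; after PATCH 9 each row is a dict keyed by the
    # TaskRow fields (task_id, request_id, method, state, start, ...)
    listing = rpc('ps', active=True, completed=True)
    for row in listing.get('result', {}).get('rows', []):
        print(row['task_id'], row['method'], row['state'])
        # 'kill' takes the task_id string reported by 'ps'
        if row['state'] == 'running':
            rpc('kill', task_id=row['task_id'], request_id=2)

Note that per process_kill(), `kill` only reports killed=true for tasks whose subprocess is still alive; already-finished or not-yet-started tasks come back with the corresponding flags unset.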