-
Notifications
You must be signed in to change notification settings - Fork 7
/
sql_commands.py
218 lines (178 loc) · 7.2 KB
/
sql_commands.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import base64
import signal
import threading
from datetime import datetime
from typing import Dict, Any
from dbt.flags import env_set_truthy
from dbt.adapters.factory import get_adapter
from dbt.clients.jinja import extract_toplevel_blocks
from dbt.config.runtime import RuntimeConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import RPCNode
from dbt_rpc.contracts.rpc import RPCExecParameters
from dbt_rpc.contracts.rpc import RemoteExecutionResult
from dbt.exceptions import RPCKilledException, DbtInternalError
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.parser.manifest import process_node, process_macro
from dbt_rpc.parser.rpc import RPCCallParser, RPCMacroParser
from dbt_rpc.rpc.error import invalid_params
from dbt_rpc.rpc.node_runners import RPCCompileRunner, RPCExecuteRunner
from dbt.task.compile import CompileTask
from dbt.task.run import RunTask
from .base import RPCTask
# Evaluated once at import time from the DBT_SINGLE_THREADED_HANDLER
# environment variable; add_new_refs consults it alongside
# config.args.single_threaded.
SINGLE_THREADED_HANDLER = env_set_truthy('DBT_SINGLE_THREADED_HANDLER')


def add_new_refs(
    manifest: Manifest,
    config: RuntimeConfig,
    node: RPCNode,
    macros: Dict[str, Any]
) -> None:
    """Run regular ref processing for a freshly parsed RPC node.

    The node itself is not inserted here: RPCCallParser has already added it
    to the manifest, so this only registers the macro overrides and then
    processes the node as if it were part of normal ref resolution.

    :param manifest: The manifest the RPC node was parsed into.
    :param config: The active runtime config (``config.args.single_threaded``
        is read).
    :param node: The RPC node to process.
    :param macros: Macro overrides keyed by unique id; they are allowed to
        silently replace a project macro with the same name.
    """
    # When single-threaded (CLI flag or env var), work against a deep copy.
    # NOTE(review): that copy is local and never returned, so whatever is
    # recorded on it is dropped when this function returns — confirm intended.
    working_manifest = (
        manifest.deepcopy()
        if config.args.single_threaded or SINGLE_THREADED_HANDLER
        else manifest
    )
    working_manifest.macros.update(macros)
    for override in macros.values():
        process_macro(config, working_manifest, override)
    # 'manifest.add_nodes({node.unique_id: node})' used to be called here; the
    # RPCCallParser now saves the node to the Manifest itself, so only the
    # processing step remains.
    process_node(config, working_manifest, node)
class RemoteRunSQLTask(RPCTask[RPCExecParameters]):
    """Base RPC task for requests that carry a base64-encoded SQL payload.

    Top-level ``macro`` blocks in the payload are parsed as macro overrides,
    the remainder is parsed into a single RPC node, and that node is run on a
    worker thread by the runner returned from ``get_runner``.
    """

    def runtime_cleanup(self, selected_uids):
        """Reset per-run state that a Task usually sets up in ``__init__``.

        :param selected_uids: Unique ids of the nodes selected for this run;
            only its length is used.
        """
        self.run_count = 0
        self.num_nodes = len(selected_uids)
        self.node_results = []
        self._skipped_children = {}
        self._raise_next_tick = None

    def decode_sql(self, sql: str) -> str:
        """Base64 decode a string. This should only be used for sql in calls.

        :param str sql: The base64 encoded form of the original utf-8 string
        :return str: The decoded utf-8 string
        :raises: The ``invalid_params`` RPC error (via raise_invalid_base64)
            when ``sql`` is not strict base64 or the decoded bytes are not
            valid utf-8.
        """
        # JSON is defined as using "unicode", we'll go a step further and
        # mandate utf-8 (though for the base64 part, it doesn't really matter!)
        try:
            base64_sql_bytes = str(sql).encode('utf-8')
            sql_bytes = base64.b64decode(base64_sql_bytes, validate=True)
            # Decode inside the try: UnicodeDecodeError is a ValueError, so a
            # non-utf-8 payload is reported as invalid params rather than
            # escaping as a raw exception.
            decoded = sql_bytes.decode('utf-8')
        except ValueError:
            # binascii.Error and the Unicode codec errors all subclass
            # ValueError. raise_invalid_base64 always raises.
            self.raise_invalid_base64(sql)
        return decoded

    @staticmethod
    def raise_invalid_base64(sql):
        """Raise the ``invalid_params`` RPC error for an undecodable payload.

        :param sql: The raw value the client sent; echoed back in the error.
        """
        raise invalid_params(
            data={
                'message': 'invalid base64-encoded sql input',
                'sql': str(sql),
            }
        )

    def _extract_request_data(self, data):
        """Split a base64-encoded request into its SQL and macro text.

        :param data: The base64-encoded request payload.
        :return: ``(sql, macros)`` — ``macros`` is every top-level block of
            type ``macro`` joined with newlines; ``sql`` is every other block
            concatenated as-is.
        """
        data = self.decode_sql(data)
        macro_blocks = []
        data_chunks = []
        for block in extract_toplevel_blocks(data):
            if block.block_type_name == 'macro':
                macro_blocks.append(block.full_block)
            else:
                data_chunks.append(block.full_block)
        macros = '\n'.join(macro_blocks)
        sql = ''.join(data_chunks)
        return sql, macros

    def _get_exec_node(self):
        """Parse the request into an RPC node and compile the manifest graph.

        Macro overrides found in the payload are added to ``self.manifest``,
        the remaining SQL is parsed into an RPC node, and the manifest is
        compiled into ``self.graph`` without being written out.

        :return: The parsed RPC node.
        :raises DbtInternalError: If ``self.manifest`` has not been set.
        """
        if self.manifest is None:
            raise DbtInternalError(
                'manifest not set in _get_exec_node'
            )
        # NOTE(review): self.args.macros (set in set_args) is not consulted;
        # macros are taken only from top-level macro blocks in self.args.sql.
        # The previous code read it and immediately overwrote the value.
        sql, macros = self._extract_request_data(self.args.sql)
        macro_overrides = {}
        if macros:
            macro_parser = RPCMacroParser(self.config, self.manifest)
            macro_overrides = {
                macro.unique_id: macro
                for macro in macro_parser.parse_remote(macros)
            }
        self.manifest.macros.update(macro_overrides)
        rpc_parser = RPCCallParser(
            project=self.config,
            manifest=self.manifest,
            root_project=self.config,
        )
        rpc_node = rpc_parser.parse_remote(sql, self.args.name, self.args.language)
        add_new_refs(
            manifest=self.manifest,
            config=self.config,
            node=rpc_node,
            macros=macro_overrides
        )
        # don't write our new, weird manifest!
        adapter = get_adapter(self.config)
        compiler = adapter.get_compiler()
        self.graph = compiler.compile(self.manifest, write=False)
        # previously, this compiled the ancestors, but they are compiled at
        # runtime now.
        return rpc_node

    def _raise_set_error(self):
        """Re-raise the exception stored by the worker thread, if any."""
        if self._raise_next_tick is not None:
            raise self._raise_next_tick

    def _in_thread(self, node, thread_done):
        """Worker-thread body: build a runner for ``node`` and run it.

        Any exception is logged and stored in ``self._raise_next_tick`` for
        handle_request to re-raise. ``thread_done`` is always set on exit —
        including when building the runner fails — so handle_request's
        ``thread_done.wait()`` cannot block forever.
        """
        try:
            runner = self.get_runner(node)
            self.node_results.append(runner.safe_run(self.manifest))
        except Exception as exc:
            logger.debug('Got exception {}'.format(exc), exc_info=True)
            self._raise_next_tick = exc
        finally:
            thread_done.set()

    def set_args(self, params: RPCExecParameters):
        """Copy the request parameters onto ``self.args``."""
        self.args.name = params.name
        self.args.sql = params.sql
        self.args.macros = params.macros
        self.args.language = params.language

    def handle_request(self) -> RemoteExecutionResult:
        """Parse the request, run its node on a worker thread, and report.

        :return: The result of ``get_result`` with the node results, elapsed
            seconds, and the (naive, UTC) end time.
        :raises RPCKilledException: On KeyboardInterrupt, after cancelling
            open connections when the adapter supports it.
        Any exception captured by the worker thread is re-raised here.
        """
        # we could get a ctrl+c at any time, including during parsing.
        thread = None
        started = datetime.utcnow()
        try:
            node = self._get_exec_node()
            selected_uids = [node.unique_id]
            self.runtime_cleanup(selected_uids)
            thread_done = threading.Event()
            thread = threading.Thread(target=self._in_thread,
                                      args=(node, thread_done))
            thread.start()
            thread_done.wait()
        except KeyboardInterrupt:
            adapter = get_adapter(self.config)  # type: ignore
            if adapter.is_cancelable():
                for conn_name in adapter.cancel_open_connections():
                    logger.debug('canceled query {}'.format(conn_name))
                # Only wait for the worker once its queries were cancelled.
                if thread:
                    thread.join()
            else:
                msg = ("The {} adapter does not support query "
                       "cancellation. Some queries may still be "
                       "running!".format(adapter.type()))
                logger.debug(msg)
            raise RPCKilledException(signal.SIGINT)
        self._raise_set_error()
        ended = datetime.utcnow()
        elapsed = (ended - started).total_seconds()
        return self.get_result(
            results=self.node_results,
            elapsed_time=elapsed,
            generated_at=ended,
        )

    def interpret_results(self, results):
        """Always report success; ``results`` is not inspected."""
        return True
class RemoteCompileTask(RemoteRunSQLTask, CompileTask):
    """RPC task exposed under the ``compile_sql`` method name.

    Every node it handles is given to ``RPCCompileRunner``.
    """

    METHOD_NAME = 'compile_sql'

    def get_runner_type(self, _):
        """Return the runner class to use; the node argument is ignored."""
        runner_cls = RPCCompileRunner
        return runner_cls
class RemoteRunTask(RemoteRunSQLTask, RunTask):
    """RPC task exposed under the ``run_sql`` method name.

    Every node it handles is given to ``RPCExecuteRunner``.
    """

    METHOD_NAME = 'run_sql'

    def get_runner_type(self, _):
        """Return the runner class to use; the node argument is ignored."""
        runner_cls = RPCExecuteRunner
        return runner_cls