-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
connections.py
182 lines (152 loc) · 6.2 KB
/
connections.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import abc
import time
from typing import List, Optional, Tuple, Any, Iterable, Dict
import agate
import dbt.clients.agate_helper
import dbt.exceptions
from dbt.adapters.base import BaseConnectionManager
from dbt.contracts.connection import Connection, ConnectionState, AdapterResponse
from dbt.events.functions import fire_event
from dbt.events.types import ConnectionUsed, SQLQuery, SQLCommit, SQLQueryStatus
from dbt.events.contextvars import get_node_info
from dbt.utils import cast_to_str
class SQLConnectionManager(BaseConnectionManager):
"""The default connection manager with some common SQL methods implemented.
Methods to implement:
- exception_handler
- cancel
- get_response
- open
"""
@abc.abstractmethod
def cancel(self, connection: Connection):
"""Cancel the given connection."""
raise dbt.exceptions.NotImplementedError("`cancel` is not implemented for this adapter!")
def cancel_open(self) -> List[str]:
names = []
this_connection = self.get_if_exists()
with self.lock:
for connection in self.thread_connections.values():
if connection is this_connection:
continue
# if the connection failed, the handle will be None so we have
# nothing to cancel.
if connection.handle is not None and connection.state == ConnectionState.OPEN:
self.cancel(connection)
if connection.name is not None:
names.append(connection.name)
return names
def add_query(
self,
sql: str,
auto_begin: bool = True,
bindings: Optional[Any] = None,
abridge_sql_log: bool = False,
) -> Tuple[Connection, Any]:
connection = self.get_thread_connection()
if auto_begin and connection.transaction_open is False:
self.begin()
fire_event(
ConnectionUsed(
conn_type=self.TYPE,
conn_name=cast_to_str(connection.name),
node_info=get_node_info(),
)
)
with self.exception_handler(sql):
if abridge_sql_log:
log_sql = "{}...".format(sql[:512])
else:
log_sql = sql
fire_event(
SQLQuery(
conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info()
)
)
pre = time.time()
cursor = connection.handle.cursor()
cursor.execute(sql, bindings)
fire_event(
SQLQueryStatus(
status=str(self.get_response(cursor)),
elapsed=round((time.time() - pre)),
node_info=get_node_info(),
)
)
return connection, cursor
@classmethod
@abc.abstractmethod
def get_response(cls, cursor: Any) -> AdapterResponse:
"""Get the status of the cursor."""
raise dbt.exceptions.NotImplementedError(
"`get_response` is not implemented for this adapter!"
)
@classmethod
def process_results(
cls, column_names: Iterable[str], rows: Iterable[Any]
) -> List[Dict[str, Any]]:
# TODO CT-211
unique_col_names = dict() # type: ignore[var-annotated]
# TODO CT-211
for idx in range(len(column_names)): # type: ignore[arg-type]
# TODO CT-211
col_name = column_names[idx] # type: ignore[index]
if col_name in unique_col_names:
unique_col_names[col_name] += 1
# TODO CT-211
column_names[idx] = f"{col_name}_{unique_col_names[col_name]}" # type: ignore[index] # noqa
else:
# TODO CT-211
unique_col_names[column_names[idx]] = 1 # type: ignore[index]
return [dict(zip(column_names, row)) for row in rows]
@classmethod
def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> agate.Table:
data: List[Any] = []
column_names: List[str] = []
if cursor.description is not None:
column_names = [col[0] for col in cursor.description]
if limit:
rows = cursor.fetchmany(limit)
else:
rows = cursor.fetchall()
data = cls.process_results(column_names, rows)
return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
) -> Tuple[AdapterResponse, agate.Table]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin)
response = self.get_response(cursor)
if fetch:
table = self.get_result_from_cursor(cursor, limit)
else:
table = dbt.clients.agate_helper.empty_table()
return response, table
def add_begin_query(self):
return self.add_query("BEGIN", auto_begin=False)
def add_commit_query(self):
return self.add_query("COMMIT", auto_begin=False)
def add_select_query(self, sql: str) -> Tuple[Connection, Any]:
sql = self._add_query_comment(sql)
return self.add_query(sql, auto_begin=False)
def begin(self):
connection = self.get_thread_connection()
if connection.transaction_open is True:
raise dbt.exceptions.DbtInternalError(
'Tried to begin a new transaction on connection "{}", but '
"it already had one open!".format(connection.name)
)
self.add_begin_query()
connection.transaction_open = True
return connection
def commit(self):
connection = self.get_thread_connection()
if connection.transaction_open is False:
raise dbt.exceptions.DbtInternalError(
'Tried to commit transaction on connection "{}", but '
"it does not have one open!".format(connection.name)
)
fire_event(SQLCommit(conn_name=connection.name, node_info=get_node_info()))
self.add_commit_query()
connection.transaction_open = False
return connection