-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement retries in BQ adapter #1963
Changes from 12 commits
21da0ed
3b45659
86f0609
b602c9c
b2c1727
3b696ee
ad0bd87
3aabe2d
0347238
33e75f8
118344b
c0e8540
9b9c1db
e03fd44
b548375
40c9328
0940309
460d73f
7c8a21d
e6b4a12
ce4c58a
5d181c3
c5c7932
43959dc
37a1288
0356a74
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7,6 +7,7 @@ | |||||
import google.oauth2 | ||||||
import google.cloud.exceptions | ||||||
import google.cloud.bigquery | ||||||
from google.api_core import retry | ||||||
|
||||||
import dbt.clients.agate_helper | ||||||
import dbt.exceptions | ||||||
|
@@ -35,6 +36,7 @@ class BigQueryCredentials(Credentials): | |||||
timeout_seconds: Optional[int] = 300 | ||||||
location: Optional[str] = None | ||||||
priority: Optional[Priority] = None | ||||||
retries: Optional[int] = 1 | ||||||
_ALIASES = { | ||||||
'project': 'database', | ||||||
'dataset': 'schema', | ||||||
|
@@ -57,6 +59,9 @@ class BigQueryConnectionManager(BaseConnectionManager): | |||||
'https://www.googleapis.com/auth/drive') | ||||||
|
||||||
QUERY_TIMEOUT = 300 | ||||||
RETRIES = 1 | ||||||
DEFAULT_INITIAL_DELAY = 1.0 # Seconds | ||||||
DEFAULT_MAXIMUM_DELAY = 1.0 # Seconds | ||||||
|
||||||
@classmethod | ||||||
def handle_error(cls, error, message, sql): | ||||||
|
@@ -170,6 +175,11 @@ def get_timeout(cls, conn): | |||||
credentials = conn.credentials | ||||||
return credentials.timeout_seconds | ||||||
|
||||||
@classmethod | ||||||
def get_retries(cls, conn): | ||||||
credentials = conn['credentials'] | ||||||
return credentials.get('retries', cls.RETRIES) | ||||||
|
||||||
@classmethod | ||||||
def get_table_from_response(cls, resp): | ||||||
column_names = [field.name for field in resp.schema] | ||||||
|
@@ -182,21 +192,18 @@ def raw_execute(self, sql, fetch=False): | |||||
|
||||||
logger.debug('On {}: {}', conn.name, sql) | ||||||
|
||||||
job_config = google.cloud.bigquery.QueryJobConfig() | ||||||
job_config.use_legacy_sql = False | ||||||
job_params = {'use_legacy_sql': False} | ||||||
|
||||||
priority = conn.credentials.priority | ||||||
if priority == Priority.Batch: | ||||||
job_config.priority = google.cloud.bigquery.QueryPriority.BATCH | ||||||
priority = conn.credentials.get('priority', 'interactive') | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The rest of this change where you're converting the job params to a dict is fine, but this should be updated to 0.15.x syntax. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be done! |
||||||
if priority == 'batch': | ||||||
kconvey marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
job_params['priority'] = google.cloud.bigquery.QueryPriority.BATCH | ||||||
else: | ||||||
job_config.priority = \ | ||||||
google.cloud.bigquery.QueryPriority.INTERACTIVE | ||||||
job_params[ | ||||||
'priority'] = google.cloud.bigquery.QueryPriority.INTERACTIVE | ||||||
|
||||||
query_job = client.query(sql, job_config) | ||||||
fn = lambda: self._query_and_results(client, sql, conn, job_params) | ||||||
|
||||||
# this blocks until the query has completed | ||||||
with self.exception_handler(sql): | ||||||
iterator = query_job.result() | ||||||
query_job, iterator = self._retry_and_handle(msg=sql, conn=conn, fn=fn) | ||||||
|
||||||
return query_job, iterator | ||||||
|
||||||
|
@@ -243,8 +250,8 @@ def create_bigquery_table(self, database, schema, table_name, callback, | |||||
view = google.cloud.bigquery.Table(view_ref) | ||||||
callback(view) | ||||||
|
||||||
with self.exception_handler(sql): | ||||||
client.create_table(view) | ||||||
fn = lambda: client.create_table(view) | ||||||
self._retry_and_handle(msg=sql, conn=conn, fn=fn) | ||||||
|
||||||
def create_view(self, database, schema, table_name, sql): | ||||||
def callback(table): | ||||||
|
@@ -257,16 +264,11 @@ def create_table(self, database, schema, table_name, sql): | |||||
conn = self.get_thread_connection() | ||||||
client = conn.handle | ||||||
|
||||||
table_ref = self.table_ref(database, schema, table_name, conn) | ||||||
kconvey marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
job_config = google.cloud.bigquery.QueryJobConfig() | ||||||
job_config.destination = table_ref | ||||||
job_config.write_disposition = 'WRITE_TRUNCATE' | ||||||
|
||||||
query_job = client.query(sql, job_config=job_config) | ||||||
job_params = {'destination': table_ref, | ||||||
'write_disposition': 'WRITE_TRUNCATE'} | ||||||
|
||||||
# this waits for the job to complete | ||||||
with self.exception_handler(sql): | ||||||
query_job.result(timeout=self.get_timeout(conn)) | ||||||
fn = lambda: self._query_and_results(client, sql, conn, job_params) | ||||||
self._retry_and_handle(msg=sql, conn=conn, fn=fn) | ||||||
|
||||||
def create_date_partitioned_table(self, database, schema, table_name): | ||||||
def callback(table): | ||||||
|
@@ -295,15 +297,76 @@ def drop_dataset(self, database, schema): | |||||
dataset = self.dataset(database, schema, conn) | ||||||
client = conn.handle | ||||||
|
||||||
with self.exception_handler('drop dataset'): | ||||||
client.delete_dataset( | ||||||
dataset, delete_contents=True, not_found_ok=True | ||||||
) | ||||||
def _drop_tables_then_dataset(): | ||||||
for table in client.list_tables(dataset): | ||||||
kconvey marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
client.delete_table(table.reference) | ||||||
client.delete_dataset(dataset) | ||||||
|
||||||
|
||||||
self._retry_and_handle( | ||||||
msg='drop dataset', conn=conn, fn=_drop_tables_then_dataset) | ||||||
|
||||||
def create_dataset(self, database, schema): | ||||||
conn = self.get_thread_connection() | ||||||
client = conn.handle | ||||||
dataset = self.dataset(database, schema, conn) | ||||||
|
||||||
with self.exception_handler('create dataset'): | ||||||
client.create_dataset(dataset, exists_ok=True) | ||||||
# Emulate 'create schema if not exists ...' | ||||||
try: | ||||||
kconvey marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
client.get_dataset(dataset) | ||||||
return | ||||||
except google.api_core.exceptions.NotFound: | ||||||
pass | ||||||
|
||||||
fn = lambda: client.create_dataset(dataset) | ||||||
self._retry_and_handle(msg='create dataset', conn=conn, fn=fn) | ||||||
|
||||||
def _query_and_results(self, client, sql, conn, job_params): | ||||||
"""Query the client and wait for results.""" | ||||||
# Cannot reuse job_config if destination is set and ddl is used | ||||||
job_config = google.cloud.bigquery.QueryJobConfig(**job_params) | ||||||
query_job = client.query(sql, job_config=job_config) | ||||||
iterator = query_job.result(timeout=self.get_timeout(conn)) | ||||||
|
||||||
return query_job, iterator | ||||||
|
||||||
def _retry_and_handle(self, msg, conn, fn): | ||||||
"""retry a function call within the context of exception_handler.""" | ||||||
with self.exception_handler(msg): | ||||||
return retry.retry_target( | ||||||
target=fn, | ||||||
predicate=_ErrorCounter(self.get_retries(conn)).count_error, | ||||||
sleep_generator=self._retry_generator(), | ||||||
deadline=None) | ||||||
|
||||||
def _retry_generator(self): | ||||||
"""Generates retry intervals that exponentially back off.""" | ||||||
return retry.exponential_sleep_generator( | ||||||
initial=self.DEFAULT_INITIAL_DELAY, | ||||||
maximum=self.DEFAULT_MAXIMUM_DELAY) | ||||||
|
||||||
class _ErrorCounter(object): | ||||||
"""Counts errors seen up to a threshold then raises the next error.""" | ||||||
|
||||||
def __init__(self, retries): | ||||||
self.retries = retries | ||||||
self.error_count = 0 | ||||||
|
||||||
def count_error(self, error): | ||||||
if self.retries == 0: | ||||||
return False # Don't log | ||||||
self.error_count +=1 | ||||||
if _is_retryable(error) and self.error_count <= self.retries: | ||||||
logger.warning( | ||||||
'Retry attempt %s of %s after error: %s', | ||||||
self.error_count, self.retries, repr(error)) | ||||||
return True | ||||||
else: | ||||||
logger.warning( | ||||||
'Not Retrying after %s previous attempts. Error: %s', | ||||||
self.error_count - 1, repr(error)) | ||||||
return False | ||||||
|
||||||
def _is_retryable(error): | ||||||
"""Return true for 500 level (retryable) errors.""" | ||||||
return isinstance(error, google.cloud.exceptions.ServerError) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
from contextlib import contextmanager | ||
from unittest.mock import patch, MagicMock, Mock | ||
|
||
import hologram | ||
|
||
import dbt.flags as flags | ||
|
||
from dbt.adapters.bigquery import BigQueryCredentials | ||
from dbt.adapters.bigquery import BigQueryAdapter | ||
from dbt.adapters.bigquery import BigQueryRelation | ||
from dbt.adapters.bigquery.connections import BigQueryConnectionManager | ||
import dbt.exceptions | ||
from dbt.logger import GLOBAL_LOGGER as logger # noqa | ||
|
||
|
@@ -287,3 +290,81 @@ def test_invalid_relation(self): | |
} | ||
with self.assertRaises(hologram.ValidationError): | ||
BigQueryRelation.from_dict(kwargs) | ||
|
||
|
||
class TestBigQueryConnectionManager(unittest.TestCase): | ||
|
||
def setUp(self): | ||
credentials = Mock(BigQueryCredentials) | ||
credentials.query_comment = None | ||
kconvey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.connections = BigQueryConnectionManager(profile=credentials) | ||
kconvey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.mock_client = Mock( | ||
dbt.adapters.bigquery.impl.google.cloud.bigquery.Client) | ||
self.mock_connection = MagicMock() | ||
|
||
self.mock_connection.handle = self.mock_client | ||
|
||
self.connections.get_thread_connection = lambda: self.mock_connection | ||
|
||
def test_retry_and_handle(self): | ||
self.connections.DEFAULT_MAXIMUM_DELAY = 2.0 | ||
dbt.adapters.bigquery.connections._is_retryable = lambda x: True | ||
|
||
@contextmanager | ||
def dummy_handler(msg): | ||
yield | ||
|
||
self.connections.exception_handler = dummy_handler | ||
|
||
class DummyException(Exception): | ||
"""Count how many times this exception is raised""" | ||
count = 0 | ||
|
||
def __init__(self): | ||
DummyException.count += 1 | ||
|
||
def raiseDummyException(): | ||
raise DummyException() | ||
|
||
# with self.assertLogs(logger.name) as logs: | ||
with self.assertRaises(DummyException): | ||
self.connections._retry_and_handle( | ||
"some sql", {'credentials': {'retries': 8}}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be a mock credentials object now, instead of a dict. Probably something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. |
||
raiseDummyException) | ||
self.assertEqual(DummyException.count, 9) | ||
# self.assertIn( | ||
# 'WARNING:dbt:Retry attempt 1 of 8 after error: DummyException()', | ||
# logs.output) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you remove this commented-out code? You can use pytest's stdout capture stuff if you can get it working in the tests instead, but otherwise I wouldn't bother too much about it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed it. |
||
def test_is_retryable(self): | ||
_is_retryable = dbt.adapters.bigquery.connections._is_retryable | ||
exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions | ||
internal_server_error = exceptions.InternalServerError('code broke') | ||
bad_request_error = exceptions.BadRequest('code broke') | ||
|
||
self.assertTrue(_is_retryable(internal_server_error)) | ||
self.assertFalse(_is_retryable(bad_request_error)) | ||
|
||
def test_drop_dataset(self): | ||
mock_table = Mock() | ||
mock_table.reference = 'table1' | ||
|
||
self.mock_client.list_tables.return_value = [mock_table] | ||
|
||
self.connections.drop_dataset('project', 'dataset') | ||
|
||
self.mock_client.list_tables.assert_called_once() | ||
kconvey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.mock_client.delete_table.assert_called_once_with('table1') | ||
kconvey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.mock_client.delete_dataset.assert_called_once() | ||
|
||
@patch('dbt.adapters.bigquery.impl.google.cloud.bigquery') | ||
def test_query_and_results(self, mock_bq): | ||
self.connections.get_timeout = lambda x: 100.0 | ||
|
||
self.connections._query_and_results( | ||
self.mock_client, 'sql', self.mock_connection, | ||
{'description': 'blah'}) | ||
|
||
mock_bq.QueryJobConfig.assert_called_once() | ||
self.mock_client.query.assert_called_once_with( | ||
'sql', job_config=mock_bq.QueryJobConfig()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you'll want to change this to just be something like
return conn.credentials.retries
. You can get rid ofRETRIES
then, too.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mirrored the syntax for get_timeout since this does a very similar thing.