Merge pull request #1963 from kconvey/feature/retries
Implement retries in BQ adapter
beckjake authored Dec 12, 2019
2 parents 2456b6d + 0356a74 commit 9222f80
Showing 3 changed files with 168 additions and 27 deletions.
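For context before the diff: this PR adds an optional `retries` credential to the BigQuery profile and routes client calls through a generic retry helper from google.api_core. A hedged sketch of the user-facing knob, written as the parsed form of a hypothetical profiles.yml target (values are made up; `project` and `dataset` are aliased to `database` and `schema` by the `_ALIASES` table in the diff below):

# Hypothetical BigQuery target as dbt would parse it from profiles.yml.
# Only `retries` is new in this PR; it defaults to 1 when omitted.
bigquery_target = {
    'type': 'bigquery',
    'method': 'oauth',
    'project': 'my-gcp-project',   # aliased to `database`
    'dataset': 'analytics',        # aliased to `schema`
    'timeout_seconds': 300,
    'retries': 3,                  # retry transient 5xx errors up to 3 times
}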
111 changes: 86 additions & 25 deletions plugins/bigquery/dbt/adapters/bigquery/connections.py
@@ -7,6 +7,7 @@
 import google.oauth2
 import google.cloud.exceptions
 import google.cloud.bigquery
+from google.api_core import retry
 
 import dbt.clients.agate_helper
 import dbt.exceptions
@@ -35,6 +36,7 @@ class BigQueryCredentials(Credentials):
     timeout_seconds: Optional[int] = 300
     location: Optional[str] = None
     priority: Optional[Priority] = None
+    retries: Optional[int] = 1
     _ALIASES = {
         'project': 'database',
         'dataset': 'schema',
@@ -57,6 +59,9 @@ class BigQueryConnectionManager(BaseConnectionManager):
             'https://www.googleapis.com/auth/drive')
 
     QUERY_TIMEOUT = 300
+    RETRIES = 1
+    DEFAULT_INITIAL_DELAY = 1.0  # Seconds
+    DEFAULT_MAXIMUM_DELAY = 1.0  # Seconds
 
     @classmethod
     def handle_error(cls, error, message, sql):
@@ -170,6 +175,11 @@ def get_timeout(cls, conn):
         credentials = conn.credentials
         return credentials.timeout_seconds
 
+    @classmethod
+    def get_retries(cls, conn):
+        credentials = conn.credentials
+        return credentials.retries
+
     @classmethod
     def get_table_from_response(cls, resp):
         column_names = [field.name for field in resp.schema]
@@ -182,21 +192,19 @@ def raw_execute(self, sql, fetch=False):
 
         logger.debug('On {}: {}', conn.name, sql)
 
-        job_config = google.cloud.bigquery.QueryJobConfig()
-        job_config.use_legacy_sql = False
+        job_params = {'use_legacy_sql': False}
 
         priority = conn.credentials.priority
         if priority == Priority.Batch:
-            job_config.priority = google.cloud.bigquery.QueryPriority.BATCH
+            job_params['priority'] = google.cloud.bigquery.QueryPriority.BATCH
         else:
-            job_config.priority = \
-                google.cloud.bigquery.QueryPriority.INTERACTIVE
+            job_params[
+                'priority'] = google.cloud.bigquery.QueryPriority.INTERACTIVE
 
-        query_job = client.query(sql, job_config)
+        def fn():
+            return self._query_and_results(client, sql, conn, job_params)
 
-        # this blocks until the query has completed
-        with self.exception_handler(sql):
-            iterator = query_job.result()
+        query_job, iterator = self._retry_and_handle(msg=sql, conn=conn, fn=fn)
 
         return query_job, iterator
 
@@ -243,8 +251,9 @@ def create_bigquery_table(self, database, schema, table_name, callback,
         view = google.cloud.bigquery.Table(view_ref)
         callback(view)
 
-        with self.exception_handler(sql):
-            client.create_table(view)
+        def fn():
+            return client.create_table(view)
+        self._retry_and_handle(msg=sql, conn=conn, fn=fn)
 
     def create_view(self, database, schema, table_name, sql):
         def callback(table):
@@ -258,15 +267,12 @@ def create_table(self, database, schema, table_name, sql):
         client = conn.handle
 
         table_ref = self.table_ref(database, schema, table_name, conn)
-        job_config = google.cloud.bigquery.QueryJobConfig()
-        job_config.destination = table_ref
-        job_config.write_disposition = 'WRITE_TRUNCATE'
-
-        query_job = client.query(sql, job_config=job_config)
+        job_params = {'destination': table_ref,
+                      'write_disposition': 'WRITE_TRUNCATE'}
 
-        # this waits for the job to complete
-        with self.exception_handler(sql):
-            query_job.result(timeout=self.get_timeout(conn))
+        def fn():
+            return self._query_and_results(client, sql, conn, job_params)
+        self._retry_and_handle(msg=sql, conn=conn, fn=fn)
 
     def create_date_partitioned_table(self, database, schema, table_name):
         def callback(table):
@@ -295,15 +301,70 @@ def drop_dataset(self, database, schema):
         dataset = self.dataset(database, schema, conn)
         client = conn.handle
 
-        with self.exception_handler('drop dataset'):
-            client.delete_dataset(
-                dataset, delete_contents=True, not_found_ok=True
-            )
+        def fn():
+            return client.delete_dataset(
+                dataset, delete_contents=True, not_found_ok=True)
+
+        self._retry_and_handle(
+            msg='drop dataset', conn=conn, fn=fn)
 
     def create_dataset(self, database, schema):
         conn = self.get_thread_connection()
         client = conn.handle
         dataset = self.dataset(database, schema, conn)
 
-        with self.exception_handler('create dataset'):
-            client.create_dataset(dataset, exists_ok=True)
+        def fn():
+            return client.create_dataset(dataset, exists_ok=True)
+        self._retry_and_handle(msg='create dataset', conn=conn, fn=fn)
+
+    def _query_and_results(self, client, sql, conn, job_params):
+        """Query the client and wait for results."""
+        # Cannot reuse job_config if destination is set and ddl is used
+        job_config = google.cloud.bigquery.QueryJobConfig(**job_params)
+        query_job = client.query(sql, job_config=job_config)
+        iterator = query_job.result(timeout=self.get_timeout(conn))
+
+        return query_job, iterator
+
+    def _retry_and_handle(self, msg, conn, fn):
+        """retry a function call within the context of exception_handler."""
+        with self.exception_handler(msg):
+            return retry.retry_target(
+                target=fn,
+                predicate=_ErrorCounter(self.get_retries(conn)).count_error,
+                sleep_generator=self._retry_generator(),
+                deadline=None)
+
+    def _retry_generator(self):
+        """Generates retry intervals that exponentially back off."""
+        return retry.exponential_sleep_generator(
+            initial=self.DEFAULT_INITIAL_DELAY,
+            maximum=self.DEFAULT_MAXIMUM_DELAY)
+
+
+class _ErrorCounter(object):
+    """Counts errors seen up to a threshold then raises the next error."""
+
+    def __init__(self, retries):
+        self.retries = retries
+        self.error_count = 0
+
+    def count_error(self, error):
+        if self.retries == 0:
+            return False  # Don't log
+        self.error_count += 1
+        if _is_retryable(error) and self.error_count <= self.retries:
+            logger.warning(
+                'Retry attempt %s of %s after error: %s',
+                self.error_count, self.retries, repr(error))
+            return True
+        else:
+            logger.warning(
+                'Not Retrying after %s previous attempts. Error: %s',
+                self.error_count - 1, repr(error))
+            return False
+
+
+def _is_retryable(error):
+    """Return true for 500 level (retryable) errors."""
+    return isinstance(error, google.cloud.exceptions.ServerError)
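Taken together, `_retry_and_handle`, `_retry_generator`, `_ErrorCounter`, and `_is_retryable` form one small retry loop. A self-contained sketch of the same pattern outside dbt, runnable with google-api-core and google-cloud-bigquery installed — `flaky_query` is a hypothetical stand-in for `_query_and_results`, and the constants mirror the class attributes above:

from google.api_core import retry
import google.cloud.exceptions

RETRIES = 1
DEFAULT_INITIAL_DELAY = 1.0  # seconds
DEFAULT_MAXIMUM_DELAY = 1.0  # seconds

def _is_retryable(error):
    # Same check as in the diff above: only 5xx server errors are retried.
    return isinstance(error, google.cloud.exceptions.ServerError)

def make_predicate(retries):
    # Closure-based equivalent of _ErrorCounter.count_error.
    state = {'errors': 0}

    def count_error(error):
        state['errors'] += 1
        return _is_retryable(error) and state['errors'] <= retries

    return count_error

def flaky_query():
    # Hypothetical stand-in for client.query(sql).result().
    raise google.cloud.exceptions.InternalServerError('backend error')

try:
    retry.retry_target(
        target=flaky_query,
        predicate=make_predicate(RETRIES),
        sleep_generator=retry.exponential_sleep_generator(
            initial=DEFAULT_INITIAL_DELAY, maximum=DEFAULT_MAXIMUM_DELAY),
        deadline=None)  # no wall-clock limit; the predicate bounds attempts
except google.cloud.exceptions.InternalServerError:
    print('gave up after RETRIES + 1 attempts')

Because `deadline=None`, `retry_target` keeps calling the target until the predicate returns False, so attempts are bounded only by the error counter; and with `maximum` equal to `initial`, the "exponential" backoff is effectively a flat one-second sleep between attempts.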
5 changes: 4 additions & 1 deletion plugins/bigquery/dbt/adapters/bigquery/impl.py
@@ -93,12 +93,15 @@ def list_schemas(self, database: str) -> List[str]:
         conn = self.connections.get_thread_connection()
         client = conn.handle
 
-        with self.connections.exception_handler('list dataset'):
+        def query_schemas():
             # this is similar to how we have to deal with listing tables
             all_datasets = client.list_datasets(project=database,
                                                 max_results=10000)
             return [ds.dataset_id for ds in all_datasets]
 
+        return self.connections._retry_and_handle(
+            msg='list dataset', conn=conn, fn=query_schemas)
+
     @available.parse(lambda *a, **k: False)
     def check_schema_exists(self, database: str, schema: str) -> bool:
         conn = self.connections.get_thread_connection()
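Note the shape of the new call site above: the client call is bound into a zero-argument closure so that `retry_target` can re-invoke it from scratch on every attempt. A tiny self-contained illustration of that convention, with a mocked client (names hypothetical):

from unittest.mock import Mock

# Everything the call needs is captured by the closure, so each retry
# re-issues the identical list_datasets request.
client = Mock()
client.list_datasets.return_value = []

def query_schemas():
    all_datasets = client.list_datasets(project='my-gcp-project',
                                        max_results=10000)
    return [ds.dataset_id for ds in all_datasets]

print(query_schemas())  # []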
79 changes: 78 additions & 1 deletion test/unit/test_bigquery_adapter.py
@@ -1,12 +1,15 @@
 import unittest
-from unittest.mock import patch, MagicMock
+from contextlib import contextmanager
+from unittest.mock import patch, MagicMock, Mock
 
 import hologram
 
 import dbt.flags as flags
 
+from dbt.adapters.bigquery import BigQueryCredentials
 from dbt.adapters.bigquery import BigQueryAdapter
 from dbt.adapters.bigquery import BigQueryRelation
+from dbt.adapters.bigquery.connections import BigQueryConnectionManager
 import dbt.exceptions
 from dbt.logger import GLOBAL_LOGGER as logger  # noqa

@@ -293,3 +296,77 @@ def test_invalid_relation(self):
         }
         with self.assertRaises(hologram.ValidationError):
             BigQueryRelation.from_dict(kwargs)
+
+
+class TestBigQueryConnectionManager(unittest.TestCase):
+
+    def setUp(self):
+        credentials = Mock(BigQueryCredentials)
+        profile = Mock(query_comment=None, credentials=credentials)
+        self.connections = BigQueryConnectionManager(profile=profile)
+        self.mock_client = Mock(
+            dbt.adapters.bigquery.impl.google.cloud.bigquery.Client)
+        self.mock_connection = MagicMock()
+
+        self.mock_connection.handle = self.mock_client
+
+        self.connections.get_thread_connection = lambda: self.mock_connection
+
+    def test_retry_and_handle(self):
+        self.connections.DEFAULT_MAXIMUM_DELAY = 2.0
+        dbt.adapters.bigquery.connections._is_retryable = lambda x: True
+
+        @contextmanager
+        def dummy_handler(msg):
+            yield
+
+        self.connections.exception_handler = dummy_handler
+
+        class DummyException(Exception):
+            """Count how many times this exception is raised"""
+            count = 0
+
+            def __init__(self):
+                DummyException.count += 1
+
+        def raiseDummyException():
+            raise DummyException()
+
+        with self.assertRaises(DummyException):
+            self.connections._retry_and_handle(
+                "some sql", Mock(credentials=Mock(retries=8)),
+                raiseDummyException)
+        self.assertEqual(DummyException.count, 9)
+
+    def test_is_retryable(self):
+        _is_retryable = dbt.adapters.bigquery.connections._is_retryable
+        exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions
+        internal_server_error = exceptions.InternalServerError('code broke')
+        bad_request_error = exceptions.BadRequest('code broke')
+
+        self.assertTrue(_is_retryable(internal_server_error))
+        self.assertFalse(_is_retryable(bad_request_error))
+
+    def test_drop_dataset(self):
+        mock_table = Mock()
+        mock_table.reference = 'table1'
+
+        self.mock_client.list_tables.return_value = [mock_table]
+
+        self.connections.drop_dataset('project', 'dataset')
+
+        self.mock_client.list_tables.assert_not_called()
+        self.mock_client.delete_table.assert_not_called()
+        self.mock_client.delete_dataset.assert_called_once()
+
+    @patch('dbt.adapters.bigquery.impl.google.cloud.bigquery')
+    def test_query_and_results(self, mock_bq):
+        self.connections.get_timeout = lambda x: 100.0
+
+        self.connections._query_and_results(
+            self.mock_client, 'sql', self.mock_connection,
+            {'description': 'blah'})
+
+        mock_bq.QueryJobConfig.assert_called_once()
+        self.mock_client.query.assert_called_once_with(
+            'sql', job_config=mock_bq.QueryJobConfig())