Skip to content

Commit

Permalink
[AIRFLOW-4335] Add default num_retries to GCP connection (#5117)
Browse files Browse the repository at this point in the history
Add default num_retries to GCP connection
  • Loading branch information
ryanyuan authored and potiuk committed Apr 20, 2019
1 parent d3d417d commit 16e7e61
Show file tree
Hide file tree
Showing 18 changed files with 201 additions and 140 deletions.
63 changes: 35 additions & 28 deletions airflow/contrib/hooks/bigquery_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self,
gcp_conn_id=bigquery_conn_id, delegate_to=delegate_to)
self.use_legacy_sql = use_legacy_sql
self.location = location
self.num_retries = self._get_field('num_retries', 5)

def get_conn(self):
"""
Expand All @@ -72,6 +73,7 @@ def get_conn(self):
project_id=project,
use_legacy_sql=self.use_legacy_sql,
location=self.location,
num_retries=self.num_retries
)

def get_service(self):
Expand Down Expand Up @@ -137,7 +139,7 @@ def table_exists(self, project_id, dataset_id, table_id):
try:
service.tables().get(
projectId=project_id, datasetId=dataset_id,
tableId=table_id).execute()
tableId=table_id).execute(num_retries=self.num_retries)
return True
except HttpError as e:
if e.resp['status'] == '404':
Expand Down Expand Up @@ -210,7 +212,8 @@ def __init__(self,
project_id,
use_legacy_sql=True,
api_resource_configs=None,
location=None):
location=None,
num_retries=None):

self.service = service
self.project_id = project_id
Expand All @@ -221,6 +224,7 @@ def __init__(self,
if api_resource_configs else {}
self.running_job_id = None
self.location = location
self.num_retries = num_retries

def create_empty_table(self,
project_id,
Expand All @@ -231,7 +235,7 @@ def create_empty_table(self,
cluster_fields=None,
labels=None,
view=None,
num_retries=5):
num_retries=None):
"""
Creates a new, empty table in the dataset.
To create a view, which is defined by a SQL query, parse a dictionary to 'view' kwarg
Expand Down Expand Up @@ -304,6 +308,8 @@ def create_empty_table(self,
if view:
table_resource['view'] = view

num_retries = num_retries if num_retries else self.num_retries

self.log.info('Creating Table %s:%s.%s',
project_id, dataset_id, table_id)

Expand Down Expand Up @@ -507,7 +513,7 @@ def create_external_table(self,
projectId=project_id,
datasetId=dataset_id,
body=table_resource
).execute()
).execute(num_retries=self.num_retries)

self.log.info('External table created successfully: %s',
external_project_dataset_table)
Expand Down Expand Up @@ -615,7 +621,7 @@ def patch_table(self,
projectId=project_id,
datasetId=dataset_id,
tableId=table_id,
body=table_resource).execute()
body=table_resource).execute(num_retries=self.num_retries)

self.log.info('Table patched successfully: %s:%s.%s',
project_id, dataset_id, table_id)
Expand Down Expand Up @@ -1221,7 +1227,7 @@ def run_with_configuration(self, configuration):
# Send query and wait for reply.
query_reply = jobs \
.insert(projectId=self.project_id, body=job_data) \
.execute()
.execute(num_retries=self.num_retries)
self.running_job_id = query_reply['jobReference']['jobId']
if 'location' in query_reply['jobReference']:
location = query_reply['jobReference']['location']
Expand All @@ -1236,11 +1242,11 @@ def run_with_configuration(self, configuration):
job = jobs.get(
projectId=self.project_id,
jobId=self.running_job_id,
location=location).execute()
location=location).execute(num_retries=self.num_retries)
else:
job = jobs.get(
projectId=self.project_id,
jobId=self.running_job_id).execute()
jobId=self.running_job_id).execute(num_retries=self.num_retries)
if job['status']['state'] == 'DONE':
keep_polling_job = False
# Check if job had errors.
Expand Down Expand Up @@ -1272,10 +1278,10 @@ def poll_job_complete(self, job_id):
if self.location:
job = jobs.get(projectId=self.project_id,
jobId=job_id,
location=self.location).execute()
location=self.location).execute(num_retries=self.num_retries)
else:
job = jobs.get(projectId=self.project_id,
jobId=job_id).execute()
jobId=job_id).execute(num_retries=self.num_retries)
if job['status']['state'] == 'DONE':
return True
except HttpError as err:
Expand All @@ -1302,11 +1308,11 @@ def cancel_query(self):
jobs.cancel(
projectId=self.project_id,
jobId=self.running_job_id,
location=self.location).execute()
location=self.location).execute(num_retries=self.num_retries)
else:
jobs.cancel(
projectId=self.project_id,
jobId=self.running_job_id).execute()
jobId=self.running_job_id).execute(num_retries=self.num_retries)
else:
self.log.info('No running BigQuery jobs to cancel.')
return
Expand Down Expand Up @@ -1343,7 +1349,7 @@ def get_schema(self, dataset_id, table_id):
"""
tables_resource = self.service.tables() \
.get(projectId=self.project_id, datasetId=dataset_id, tableId=table_id) \
.execute()
.execute(num_retries=self.num_retries)
return tables_resource['schema']

def get_tabledata(self, dataset_id, table_id,
Expand Down Expand Up @@ -1376,7 +1382,7 @@ def get_tabledata(self, dataset_id, table_id,
projectId=self.project_id,
datasetId=dataset_id,
tableId=table_id,
**optional_params).execute())
**optional_params).execute(num_retries=self.num_retries))

def run_table_delete(self, deletion_dataset_table,
ignore_if_missing=False):
Expand All @@ -1403,7 +1409,7 @@ def run_table_delete(self, deletion_dataset_table,
.delete(projectId=deletion_project,
datasetId=deletion_dataset,
tableId=deletion_table) \
.execute()
.execute(num_retries=self.num_retries)
self.log.info('Deleted table %s:%s.%s.', deletion_project,
deletion_dataset, deletion_table)
except HttpError:
Expand Down Expand Up @@ -1432,7 +1438,7 @@ def run_table_upsert(self, dataset_id, table_resource, project_id=None):
table_id = table_resource['tableReference']['tableId']
project_id = project_id if project_id is not None else self.project_id
tables_list_resp = self.service.tables().list(
projectId=project_id, datasetId=dataset_id).execute()
projectId=project_id, datasetId=dataset_id).execute(num_retries=self.num_retries)
while True:
for table in tables_list_resp.get('tables', []):
if table['tableReference']['tableId'] == table_id:
Expand All @@ -1443,14 +1449,14 @@ def run_table_upsert(self, dataset_id, table_resource, project_id=None):
projectId=project_id,
datasetId=dataset_id,
tableId=table_id,
body=table_resource).execute()
body=table_resource).execute(num_retries=self.num_retries)
# If there is a next page, we need to check the next page.
if 'nextPageToken' in tables_list_resp:
tables_list_resp = self.service.tables()\
.list(projectId=project_id,
datasetId=dataset_id,
pageToken=tables_list_resp['nextPageToken'])\
.execute()
.execute(num_retries=self.num_retries)
# If there is no next page, then the table doesn't exist.
else:
# do insert
Expand All @@ -1459,7 +1465,7 @@ def run_table_upsert(self, dataset_id, table_resource, project_id=None):
return self.service.tables().insert(
projectId=project_id,
datasetId=dataset_id,
body=table_resource).execute()
body=table_resource).execute(num_retries=self.num_retries)

def run_grant_dataset_view_access(self,
source_dataset,
Expand Down Expand Up @@ -1494,7 +1500,7 @@ def run_grant_dataset_view_access(self,
# we don't want to clobber any existing accesses, so we have to get
# info on the dataset before we can add view access
source_dataset_resource = self.service.datasets().get(
projectId=source_project, datasetId=source_dataset).execute()
projectId=source_project, datasetId=source_dataset).execute(num_retries=self.num_retries)
access = source_dataset_resource[
'access'] if 'access' in source_dataset_resource else []
view_access = {
Expand All @@ -1516,7 +1522,7 @@ def run_grant_dataset_view_access(self,
datasetId=source_dataset,
body={
'access': access
}).execute()
}).execute(num_retries=self.num_retries)
else:
# if view is already in access, do nothing.
self.log.info(
Expand Down Expand Up @@ -1582,7 +1588,7 @@ def create_empty_dataset(self, dataset_id="", project_id="",
try:
self.service.datasets().insert(
projectId=dataset_project_id,
body=dataset_reference).execute()
body=dataset_reference).execute(num_retries=self.num_retries)
self.log.info('Dataset created successfully: In project %s '
'Dataset %s', dataset_project_id, dataset_id)

Expand All @@ -1607,7 +1613,7 @@ def delete_dataset(self, project_id, dataset_id):
try:
self.service.datasets().delete(
projectId=project_id,
datasetId=dataset_id).execute()
datasetId=dataset_id).execute(num_retries=self.num_retries)
self.log.info('Dataset deleted successfully: In project %s '
'Dataset %s', project_id, dataset_id)

Expand Down Expand Up @@ -1640,7 +1646,7 @@ def get_dataset(self, dataset_id, project_id=None):

try:
dataset_resource = self.service.datasets().get(
datasetId=dataset_id, projectId=dataset_project_id).execute()
datasetId=dataset_id, projectId=dataset_project_id).execute(num_retries=self.num_retries)
self.log.info("Dataset Resource: %s", dataset_resource)
except HttpError as err:
raise AirflowException(
Expand Down Expand Up @@ -1687,7 +1693,7 @@ def get_datasets_list(self, project_id=None):

try:
datasets_list = self.service.datasets().list(
projectId=dataset_project_id).execute()['datasets']
projectId=dataset_project_id).execute(num_retries=self.num_retries)['datasets']
self.log.info("Datasets List: %s", datasets_list)

except HttpError as err:
Expand Down Expand Up @@ -1751,7 +1757,7 @@ def insert_all(self, project_id, dataset_id, table_id,
resp = self.service.tabledata().insertAll(
projectId=dataset_project_id, datasetId=dataset_id,
tableId=table_id, body=body
).execute()
).execute(num_retries=self.num_retries)

if 'insertErrors' not in resp:
self.log.info(
Expand Down Expand Up @@ -1782,12 +1788,13 @@ class BigQueryCursor(BigQueryBaseCursor):
https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
"""

def __init__(self, service, project_id, use_legacy_sql=True, location=None):
def __init__(self, service, project_id, use_legacy_sql=True, location=None, num_retries=None):
super(BigQueryCursor, self).__init__(
service=service,
project_id=project_id,
use_legacy_sql=use_legacy_sql,
location=location,
num_retries=num_retries
)
self.buffersize = None
self.page_token = None
Expand Down Expand Up @@ -1855,7 +1862,7 @@ def next(self):
query_results = (self.service.jobs().getQueryResults(
projectId=self.project_id,
jobId=self.job_id,
pageToken=self.page_token).execute())
pageToken=self.page_token).execute(num_retries=self.num_retries))

if 'rows' in query_results and query_results['rows']:
self.page_token = query_results.get('pageToken')
Expand Down
Loading

0 comments on commit 16e7e61

Please sign in to comment.