Skip to content

Commit

Permalink
Split timeout in multi-request methods
Browse files — browse the repository at this point in the history
If a method makes multiple requests and is given a timeout, that
timeout should represent the total allowed time for all requests
combined.
  • Loading branch information
plamut committed Dec 20, 2019
1 parent a5188cb commit 9e0abcf
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 22 deletions.
35 changes: 24 additions & 11 deletions bigquery/google/cloud/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
except ImportError: # Python 2.7
import collections as collections_abc

import concurrent.futures
import copy
import functools
import gzip
Expand All @@ -47,6 +48,7 @@
import google.api_core.client_options
import google.api_core.exceptions
from google.api_core import page_iterator
from google.auth.transport.requests import TimeoutGuard
import google.cloud._helpers
from google.cloud import exceptions
from google.cloud.client import ClientWithProject
Expand Down Expand Up @@ -2557,21 +2559,27 @@ def list_partitions(self, table, retry=DEFAULT_RETRY, timeout=None):
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
Returns:
List[str]:
A list of the partition ids present in the partitioned table
"""
# TODO: split timeout between all API calls in the method
table = _table_arg_to_table_ref(table, default_project=self.project)
meta_table = self.get_table(
TableReference(
self.dataset(table.dataset_id, project=table.project),
"%s$__PARTITIONS_SUMMARY__" % table.table_id,
),
retry=retry,
timeout=timeout,
)

with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
meta_table = self.get_table(
TableReference(
self.dataset(table.dataset_id, project=table.project),
"%s$__PARTITIONS_SUMMARY__" % table.table_id,
),
retry=retry,
timeout=timeout,
)
timeout = guard.remaining_timeout

subset = [col for col in meta_table.schema if col.name == "partition_id"]
return [
Expand Down Expand Up @@ -2638,6 +2646,8 @@ def list_rows(
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
Returns:
google.cloud.bigquery.table.RowIterator:
Expand All @@ -2648,7 +2658,6 @@ def list_rows(
(this is distinct from the total number of rows in the
current page: ``iterator.page.num_items``).
"""
# TODO: split timeout between all internal API calls
table = _table_arg_to_table(table, default_project=self.project)

if not isinstance(table, Table):
Expand All @@ -2663,7 +2672,11 @@ def list_rows(
# No schema, but no selected_fields. Assume the developer wants all
# columns, so get the table resource for them rather than failing.
elif len(schema) == 0:
table = self.get_table(table.reference, retry=retry, timeout=timeout)
with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
table = self.get_table(table.reference, retry=retry, timeout=timeout)
timeout = guard.remaining_timeout
schema = table.schema

params = {}
Expand Down
42 changes: 32 additions & 10 deletions bigquery/google/cloud/bigquery/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from six.moves import http_client

import google.api_core.future.polling
from google.auth.transport.requests import TimeoutGuard
from google.cloud import exceptions
from google.cloud.exceptions import NotFound
from google.cloud.bigquery.dataset import Dataset
Expand Down Expand Up @@ -793,6 +794,8 @@ def result(self, retry=DEFAULT_RETRY, timeout=None):
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
Returns:
_AsyncJob: This instance.
Expand All @@ -803,10 +806,12 @@ def result(self, retry=DEFAULT_RETRY, timeout=None):
concurrent.futures.TimeoutError:
if the job did not complete in the given timeout.
"""
# TODO: combine _begin timeout with super().result() timeout!
# borrow timeout guard from google auth lib
if self.state is None:
self._begin(retry=retry, timeout=timeout)
with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
self._begin(retry=retry, timeout=timeout)
timeout = guard.remaining_timeout
# TODO: modify PollingFuture so it can pass a retry argument to done().
return super(_AsyncJob, self).result(timeout=timeout)

Expand Down Expand Up @@ -3163,6 +3168,8 @@ def result(
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
Returns:
google.cloud.bigquery.table.RowIterator:
Expand All @@ -3180,16 +3187,27 @@ def result(
If the job did not complete in the given timeout.
"""
try:
# TODO: combine timeout with timeouts passed to super().result()
# and _get_query_results (total timeout shared by both)
# borrow timeout guard from google auth lib
super(QueryJob, self).result(timeout=timeout)
guard = TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
)
with guard:
super(QueryJob, self).result(retry=retry, timeout=timeout)
timeout = guard.remaining_timeout

# Return an iterator instead of returning the job.
if not self._query_results:
self._query_results = self._client._get_query_results(
self.job_id, retry, project=self.project, location=self.location
guard = TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
)
with guard:
self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
location=self.location,
timeout=timeout,
)
timeout = guard.remaining_timeout
except exceptions.GoogleCloudError as exc:
exc.message += self._format_for_exception(self.query, self.job_id)
exc.query_job = self
Expand All @@ -3209,7 +3227,11 @@ def result(
dest_table = Table(dest_table_ref, schema=schema)
dest_table._properties["numRows"] = self._query_results.total_rows
rows = self._client.list_rows(
dest_table, page_size=page_size, retry=retry, max_results=max_results
dest_table,
page_size=page_size,
max_results=max_results,
retry=retry,
timeout=timeout,
)
rows._preserve_order = _contains_order_by(self.query)
return rows
Expand Down
4 changes: 3 additions & 1 deletion bigquery/tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import textwrap
import unittest

import freezegun
import mock
import pytest
import requests
Expand Down Expand Up @@ -4551,7 +4552,8 @@ def test_result_w_timeout(self):
client = _make_client(project=self.PROJECT, connection=connection)
job = self._make_one(self.JOB_ID, self.QUERY, client)

job.result(timeout=1.0)
with freezegun.freeze_time("1970-01-01 00:00:00", tick=False):
job.result(timeout=1.0)

self.assertEqual(len(connection.api_request.call_args_list), 3)
begin_request = connection.api_request.call_args_list[0]
Expand Down

0 comments on commit 9e0abcf

Please sign in to comment.