Merged
126 changes: 75 additions & 51 deletions google/cloud/bigquery/_pandas_helpers.py
@@ -26,6 +26,7 @@
import logging
import queue
import threading
import time
import warnings
from typing import Any, Union, Optional, Callable, Generator, List

@@ -869,6 +870,7 @@ def _download_table_bqstorage(
max_queue_size: Any = _MAX_QUEUE_SIZE_DEFAULT,
max_stream_count: Optional[int] = None,
download_state: Optional[_DownloadState] = None,
timeout: Optional[float] = None,
) -> Generator[Any, None, None]:
"""Downloads a BigQuery table using the BigQuery Storage API.

@@ -899,13 +901,18 @@
download_state (Optional[_DownloadState]):
A threadsafe state object which can be used to observe the
behavior of the worker threads created by this method.
timeout (Optional[float]):
The number of seconds to wait for the download to complete.
If None, wait indefinitely.

Yields:
pandas.DataFrame: Pandas DataFrames, one for each chunk of data
downloaded from BigQuery.

Raises:
ValueError: If attempting to read from a specific partition or snapshot.
concurrent.futures.TimeoutError:
If the download does not complete within the specified timeout.

Note:
This method requires the `google-cloud-bigquery-storage` library
@@ -973,60 +980,73 @@

worker_queue: queue.Queue[int] = queue.Queue(maxsize=max_queue_size)

with concurrent.futures.ThreadPoolExecutor(max_workers=total_streams) as pool:
try:
# Manually submit jobs and wait for download to complete rather
# than using pool.map because pool.map continues running in the
# background even if there is an exception on the main thread.
# See: https://github.com/googleapis/google-cloud-python/pull/7698
not_done = [
pool.submit(
_download_table_bqstorage_stream,
download_state,
bqstorage_client,
session,
stream,
worker_queue,
page_to_item,
)
for stream in session.streams
]

while not_done:
# Don't block on the worker threads. For performance reasons,
# we want to block on the queue's get method, instead. This
# prevents the queue from filling up, because the main thread
# has smaller gaps in time between calls to the queue's get
# method. For a detailed explanation, see:
# https://friendliness.dev/2019/06/18/python-nowait/
done, not_done = _nowait(not_done)
for future in done:
# Call result() on any finished threads to raise any
# exceptions encountered.
future.result()
# Manually manage the pool to control shutdown behavior on timeout.
Collaborator (Author):

You may wonder why we switched away from using a context manager here.

Here's a breakdown of the change:

1. Manual pool creation: instead of the with statement, the code now creates the ThreadPoolExecutor directly:

   pool = concurrent.futures.ThreadPoolExecutor(max_workers=max(1, total_streams))

2. Conditional shutdown: the finally block now passes a wait_on_shutdown boolean flag:

   pool.shutdown(wait=wait_on_shutdown)

   The flag is set to False only when a TimeoutError is raised.

3. Timeout handling: inside the while not_done: loop, new logic checks whether the elapsed time has exceeded the specified timeout. If it has, it sets wait_on_shutdown to False and raises concurrent.futures.TimeoutError.

Why this change?

Exiting a ThreadPoolExecutor context manager always waits for all futures to complete. That is not what we want once a timeout is involved: if the download times out, we need to stop waiting for the worker threads immediately rather than block until they all finish.

By managing the pool manually and passing the wait_on_shutdown flag to pool.shutdown(), we skip waiting for the threads when a timeout has occurred, so the TimeoutError propagates promptly instead of being held up by threads that may be hanging.

In short, the change is necessary for the timeout parameter to work: the function must not hang after the timeout duration is exceeded. The core logic of submitting tasks to the pool and retrieving results from the queue is unchanged; only the shutdown path is more nuanced to handle timeouts.
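
For anyone skimming the thread, here is a minimal, self-contained sketch of the pattern. It is not the library code: the four-worker pool, the toy task list, and the 0.1-second polling interval are placeholders.

import concurrent.futures
import time


def run_with_timeout(tasks, timeout=None):
    # Create the pool manually (no context manager) so the finally block
    # can decide whether shutdown() should wait for outstanding work.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
    wait_on_shutdown = True
    start_time = time.time()
    results = []
    try:
        not_done = [pool.submit(task) for task in tasks]
        while not_done:
            if timeout is not None and time.time() - start_time > timeout:
                # Stop waiting for workers so shutdown() returns immediately.
                wait_on_shutdown = False
                raise concurrent.futures.TimeoutError(
                    f"Timed out after {timeout} seconds."
                )
            done, not_done = concurrent.futures.wait(not_done, timeout=0.1)
            for future in done:
                results.append(future.result())
        return results
    finally:
        pool.shutdown(wait=wait_on_shutdown)


# The second task outlives the timeout, so a TimeoutError is raised and
# shutdown(wait=False) does not block on it.
try:
    run_with_timeout([lambda: time.sleep(0.01), lambda: time.sleep(5)], timeout=1.0)
except concurrent.futures.TimeoutError as exc:
    print(exc)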

Contributor:

Good catch!

pool = concurrent.futures.ThreadPoolExecutor(max_workers=max(1, total_streams))
wait_on_shutdown = True
start_time = time.time()

try:
frame = worker_queue.get(timeout=_PROGRESS_INTERVAL)
yield frame
except queue.Empty: # pragma: NO COVER
continue
try:
# Manually submit jobs and wait for download to complete rather
# than using pool.map because pool.map continues running in the
# background even if there is an exception on the main thread.
# See: https://github.com/googleapis/google-cloud-python/pull/7698
not_done = [
pool.submit(
_download_table_bqstorage_stream,
download_state,
bqstorage_client,
session,
stream,
worker_queue,
page_to_item,
)
for stream in session.streams
]

while not_done:
# Check for timeout
if timeout is not None:
elapsed = time.time() - start_time
if elapsed > timeout:
wait_on_shutdown = False
raise concurrent.futures.TimeoutError(
f"Download timed out after {timeout} seconds."
)

# Don't block on the worker threads. For performance reasons,
# we want to block on the queue's get method, instead. This
# prevents the queue from filling up, because the main thread
# has smaller gaps in time between calls to the queue's get
# method. For a detailed explanation, see:
# https://friendliness.dev/2019/06/18/python-nowait/
done, not_done = _nowait(not_done)
for future in done:
# Call result() on any finished threads to raise any
# exceptions encountered.
future.result()

try:
frame = worker_queue.get(timeout=_PROGRESS_INTERVAL)
yield frame
except queue.Empty: # pragma: NO COVER
continue

# Return any remaining values after the workers finished.
while True: # pragma: NO COVER
try:
frame = worker_queue.get_nowait()
yield frame
except queue.Empty: # pragma: NO COVER
break
finally:
# No need for a lock because reading/replacing a variable is
# defined to be an atomic operation in the Python language
# definition (enforced by the global interpreter lock).
download_state.done = True
# Return any remaining values after the workers finished.
while True: # pragma: NO COVER
try:
frame = worker_queue.get_nowait()
yield frame
except queue.Empty: # pragma: NO COVER
break
finally:
# No need for a lock because reading/replacing a variable is
# defined to be an atomic operation in the Python language
# definition (enforced by the global interpreter lock).
download_state.done = True

# Shutdown all background threads, now that they should know to
# exit early.
pool.shutdown(wait=True)
# Shutdown all background threads, now that they should know to
# exit early.
pool.shutdown(wait=wait_on_shutdown)


@@ -1037,6 +1057,7 @@ def download_arrow_bqstorage(
selected_fields=None,
max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
max_stream_count=None,
timeout=None,
):
return _download_table_bqstorage(
project_id,
Expand All @@ -1047,6 +1068,7 @@ def download_arrow_bqstorage(
page_to_item=_bqstorage_page_to_arrow,
max_queue_size=max_queue_size,
max_stream_count=max_stream_count,
timeout=timeout,
)


@@ -1060,6 +1082,7 @@ def download_dataframe_bqstorage(
selected_fields=None,
max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
max_stream_count=None,
timeout=None,
):
page_to_item = functools.partial(_bqstorage_page_to_dataframe, column_names, dtypes)
return _download_table_bqstorage(
Expand All @@ -1071,6 +1094,7 @@ def download_dataframe_bqstorage(
page_to_item=page_to_item,
max_queue_size=max_queue_size,
max_stream_count=max_stream_count,
timeout=timeout,
)


17 changes: 17 additions & 0 deletions google/cloud/bigquery/job/query.py
@@ -1857,6 +1857,7 @@ def to_arrow(
bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None,
create_bqstorage_client: bool = True,
max_results: Optional[int] = None,
timeout: Optional[float] = None,
) -> "pyarrow.Table":
"""[Beta] Create a class:`pyarrow.Table` by loading all pages of a
table or query.
@@ -1904,6 +1905,10 @@

.. versionadded:: 2.21.0

timeout (Optional[float]):
The number of seconds to wait for the underlying download to complete.
If ``None``, wait indefinitely.

Returns:
pyarrow.Table
A :class:`pyarrow.Table` populated with row data and column
@@ -1921,6 +1926,7 @@
progress_bar_type=progress_bar_type,
bqstorage_client=bqstorage_client,
create_bqstorage_client=create_bqstorage_client,
timeout=timeout,
)

# If changing the signature of this method, make sure to apply the same
@@ -1949,6 +1955,7 @@ def to_dataframe(
range_timestamp_dtype: Union[
Any, None
] = DefaultPandasDTypes.RANGE_TIMESTAMP_DTYPE,
timeout: Optional[float] = None,
) -> "pandas.DataFrame":
"""Return a pandas DataFrame from a QueryJob

@@ -2141,6 +2148,10 @@

.. versionadded:: 3.21.0

timeout (Optional[float]):
The number of seconds to wait for the underlying download to complete.
If ``None``, wait indefinitely.

Returns:
pandas.DataFrame:
A :class:`~pandas.DataFrame` populated with row data
@@ -2174,6 +2185,7 @@
range_date_dtype=range_date_dtype,
range_datetime_dtype=range_datetime_dtype,
range_timestamp_dtype=range_timestamp_dtype,
timeout=timeout,
)

# If changing the signature of this method, make sure to apply the same
@@ -2191,6 +2203,7 @@ def to_geodataframe(
int_dtype: Union[Any, None] = DefaultPandasDTypes.INT_DTYPE,
float_dtype: Union[Any, None] = None,
string_dtype: Union[Any, None] = None,
timeout: Optional[float] = None,
) -> "geopandas.GeoDataFrame":
"""Return a GeoPandas GeoDataFrame from a QueryJob

@@ -2269,6 +2282,9 @@
then the data type will be ``numpy.dtype("object")``. BigQuery String
type can be found at:
https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#string_type
timeout (Optional[float]):
The number of seconds to wait for the underlying download to complete.
If ``None``, wait indefinitely.

Returns:
geopandas.GeoDataFrame:
@@ -2296,6 +2312,7 @@
int_dtype=int_dtype,
float_dtype=float_dtype,
string_dtype=string_dtype,
timeout=timeout,
)

def __iter__(self):
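
For context, a short usage sketch from the caller's side. The query text and the 60-second value are illustrative, and the timeout only takes effect on the BigQuery Storage API download path that this change modifies.

from google.cloud import bigquery

client = bigquery.Client()
query_job = client.query(
    "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013` LIMIT 100"
)

# Wait at most 60 seconds for the result download; if it takes longer,
# concurrent.futures.TimeoutError is raised.
df = query_job.to_dataframe(timeout=60.0)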