From bfdeb3fdbc1d5b26fcd3d1433abfb0be49d12018 Mon Sep 17 00:00:00 2001 From: Lingqing Gan Date: Mon, 10 Jun 2024 11:49:02 -0700 Subject: [PATCH] feat: add prefer_bqstorage_client option for Connection (#1945) --- google/cloud/bigquery/dbapi/connection.py | 30 +++++++++++++++-------- tests/unit/test_dbapi_connection.py | 20 +++++++++++++++ 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/google/cloud/bigquery/dbapi/connection.py b/google/cloud/bigquery/dbapi/connection.py index 66dee7dfb..a1a69b8fe 100644 --- a/google/cloud/bigquery/dbapi/connection.py +++ b/google/cloud/bigquery/dbapi/connection.py @@ -35,12 +35,18 @@ class Connection(object): A client that uses the faster BigQuery Storage API to fetch rows from BigQuery. If not passed, it is created using the same credentials as ``client`` (provided that BigQuery Storage dependencies are installed). - - If both clients are available, ``bqstorage_client`` is used for - fetching query results. + prefer_bqstorage_client (Optional[bool]): + Prefer the BigQuery Storage client over the REST client. If Storage + client isn't available, fall back to the REST client. Defaults to + ``True``. """ - def __init__(self, client=None, bqstorage_client=None): + def __init__( + self, + client=None, + bqstorage_client=None, + prefer_bqstorage_client=True, + ): if client is None: client = bigquery.Client() self._owns_client = True @@ -49,7 +55,10 @@ def __init__(self, client=None, bqstorage_client=None): # A warning is already raised by the BQ Storage client factory factory if # instantiation fails, or if the given BQ Storage client instance is outdated. - if bqstorage_client is None: + if not prefer_bqstorage_client: + bqstorage_client = None + self._owns_bqstorage_client = False + elif bqstorage_client is None: bqstorage_client = client._ensure_bqstorage_client() self._owns_bqstorage_client = bqstorage_client is not None else: @@ -95,7 +104,7 @@ def cursor(self): return new_cursor -def connect(client=None, bqstorage_client=None): +def connect(client=None, bqstorage_client=None, prefer_bqstorage_client=True): """Construct a DB-API connection to Google BigQuery. Args: @@ -108,11 +117,12 @@ def connect(client=None, bqstorage_client=None): A client that uses the faster BigQuery Storage API to fetch rows from BigQuery. If not passed, it is created using the same credentials as ``client`` (provided that BigQuery Storage dependencies are installed). - - If both clients are available, ``bqstorage_client`` is used for - fetching query results. + prefer_bqstorage_client (Optional[bool]): + Prefer the BigQuery Storage client over the REST client. If Storage + client isn't available, fall back to the REST client. Defaults to + ``True``. Returns: google.cloud.bigquery.dbapi.Connection: A new DB-API connection to BigQuery. """ - return Connection(client, bqstorage_client) + return Connection(client, bqstorage_client, prefer_bqstorage_client) diff --git a/tests/unit/test_dbapi_connection.py b/tests/unit/test_dbapi_connection.py index 4071e57e0..f5c77c448 100644 --- a/tests/unit/test_dbapi_connection.py +++ b/tests/unit/test_dbapi_connection.py @@ -122,6 +122,26 @@ def test_connect_w_both_clients(self): self.assertIs(connection._client, mock_client) self.assertIs(connection._bqstorage_client, mock_bqstorage_client) + def test_connect_prefer_bqstorage_client_false(self): + pytest.importorskip("google.cloud.bigquery_storage") + from google.cloud.bigquery.dbapi import connect + from google.cloud.bigquery.dbapi import Connection + + mock_client = self._mock_client() + mock_bqstorage_client = self._mock_bqstorage_client() + mock_client._ensure_bqstorage_client.return_value = mock_bqstorage_client + + connection = connect( + client=mock_client, + bqstorage_client=mock_bqstorage_client, + prefer_bqstorage_client=False, + ) + + mock_client._ensure_bqstorage_client.assert_not_called() + self.assertIsInstance(connection, Connection) + self.assertIs(connection._client, mock_client) + self.assertIs(connection._bqstorage_client, None) + def test_raises_error_if_closed(self): from google.cloud.bigquery.dbapi.exceptions import ProgrammingError