feat: accepts a table ID, which downloads the table without a query #443

Merged Dec 22, 2021 · 22 commits · changes shown from 7 commits
114 changes: 75 additions & 39 deletions pandas_gbq/gbq.py
@@ -3,6 +3,7 @@
# license that can be found in the LICENSE file.

import logging
import re
import time
import warnings
from datetime import datetime
@@ -64,6 +65,10 @@ def _test_google_api_imports():
raise ImportError("pandas-gbq requires google-cloud-bigquery") from ex


def _is_query(query_or_table: str) -> bool:
return re.search(r"\s", query_or_table.strip(), re.MULTILINE) is not None
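
For clarity (an illustration, not part of the diff): this whitespace heuristic works because a valid table ID such as project.dataset.table can never contain whitespace, while practically any SQL statement does. A quick sketch of its behavior:

_is_query("SELECT 1")                   # True: contains whitespace
_is_query("my-project.dataset.table")   # False: no whitespace, treated as a table ID
_is_query("  dataset.table  ")          # False: outer whitespace is stripped first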


class DatasetCreationError(ValueError):
"""
Raised when the create dataset method fails
@@ -374,6 +379,28 @@ def process_http_error(ex):

raise GenericGBQException("Reason: {0}".format(ex))

def download_table(
self, table_id, max_results=None, progress_bar_type=None, dtypes=None
):
Review comment:

(nit) Maybe completely annotate new methods?

Collaborator (author) reply:

Good call, done.

I probably could have used Dict[str, Any] like we did here for dtypes, but maybe something more specific would be better?

Note: we aren't currently running any sort of type checking in this repo. There's an open issue for it here: #325

Review comment:

For a start, Dict[str, Any] is probably just fine, especially if type checks are not implemented yet. Let's just make sure it's consistent with the docstring, which currently says Optional[Map[str, Union[str, pandas.Series.dtype]]].

self._start_timer()

try:
# Get the table schema, so that we can list rows.
table_ref = bigquery.TableReference.from_string(
table_id, default_project=self.project_id
)
destination = self.client.get_table(table_ref)
rows_iter = self.client.list_rows(destination, max_results=max_results)
except self.http_error as ex:
self.process_http_error(ex)

return self._download_results(
rows_iter,
max_results=max_results,
progress_bar_type=progress_bar_type,
user_dtypes=dtypes,
)
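
As a reference sketch (outside the diff): TableReference.from_string resolves partial IDs against default_project, so both dataset.table and project.dataset.table forms work here. The project name below is made up:

from google.cloud import bigquery

ref = bigquery.TableReference.from_string("dataset.table", default_project="my-project")
print(ref.project, ref.dataset_id, ref.table_id)  # my-project dataset table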

def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs):
from concurrent.futures import TimeoutError
from google.auth.exceptions import RefreshError
@@ -390,15 +417,6 @@ def run_query(query, max_results=None, progress_bar_type=None, **kwargs):
if config is not None:
job_config.update(config)

if "query" in config and "query" in config["query"]:
if query is not None:
raise ValueError(
"Query statement can't be specified "
"inside config while it is specified "
"as parameter"
)
query = config["query"].pop("query")

self._start_timer()

try:
@@ -464,15 +482,25 @@ def run_query(query, max_results=None, progress_bar_type=None, **kwargs):
)

dtypes = kwargs.get("dtypes")

# Ensure destination is populated.
try:
query_reply.result()
except self.http_error as ex:
self.process_http_error(ex)

# Get the table schema, so that we can list rows.
destination = self.client.get_table(query_reply.destination)
rows_iter = self.client.list_rows(destination, max_results=max_results)
return self._download_results(
query_reply,
rows_iter,
max_results=max_results,
progress_bar_type=progress_bar_type,
user_dtypes=dtypes,
)
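
The change above follows the standard google-cloud-bigquery pattern: waiting on result() guarantees the destination table exists before rows are listed. A minimal standalone sketch (client is an authenticated bigquery.Client; the query text is illustrative):

job = client.query("SELECT 1")
job.result()  # block until the job finishes so job.destination is populated
table = client.get_table(job.destination)
rows = client.list_rows(table, max_results=10)
df = rows.to_dataframe()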

def _download_results(
self, query_job, max_results=None, progress_bar_type=None, user_dtypes=None,
self, rows_iter, max_results=None, progress_bar_type=None, user_dtypes=None,
):
# No results are desired, so don't bother downloading anything.
if max_results == 0:
@@ -504,13 +532,10 @@ def _download_results(
to_dataframe_kwargs["create_bqstorage_client"] = create_bqstorage_client

try:
query_job.result()
# Get the table schema, so that we can list rows.
destination = self.client.get_table(query_job.destination)
rows_iter = self.client.list_rows(destination, max_results=max_results)

schema_fields = [field.to_api_repr() for field in rows_iter.schema]
conversion_dtypes = _bqschema_to_nullsafe_dtypes(schema_fields)
# ENDTODO: This is the only difference between table ID and

conversion_dtypes.update(user_dtypes)
df = rows_iter.to_dataframe(
dtypes=conversion_dtypes,
@@ -644,7 +669,7 @@ def _cast_empty_df_dtypes(schema_fields, df):


def read_gbq(
query,
query_or_table,
project_id=None,
index_col=None,
col_order=None,
@@ -668,17 +693,18 @@

This method uses the Google Cloud client library to make requests to
Google BigQuery, documented `here
<https://google-cloud-python.readthedocs.io/en/latest/bigquery/usage.html>`__.
<https://googleapis.dev/python/bigquery/latest/index.html>`__.

See the :ref:`How to authenticate with Google BigQuery <authentication>`
guide for authentication instructions.

Parameters
----------
query : str
SQL-Like Query to return data values.
query_or_table : str
SQL query to return data values. If the string is a table ID, fetch the
rows directly from the table without running a query.
project_id : str, optional
Google BigQuery Account project ID. Optional when available from
Google Cloud Platform project ID. Optional when available from
the environment.
index_col : str, optional
Name of result column to use for index in results DataFrame.
@@ -693,9 +719,9 @@
when getting user credentials.

.. _local webserver flow:
http://google-auth-oauthlib.readthedocs.io/en/latest/reference/google_auth_oauthlib.flow.html#google_auth_oauthlib.flow.InstalledAppFlow.run_local_server
https://googleapis.dev/python/google-auth-oauthlib/latest/reference/google_auth_oauthlib.flow.html#google_auth_oauthlib.flow.InstalledAppFlow.run_local_server
.. _console flow:
http://google-auth-oauthlib.readthedocs.io/en/latest/reference/google_auth_oauthlib.flow.html#google_auth_oauthlib.flow.InstalledAppFlow.run_console
https://googleapis.dev/python/google-auth-oauthlib/latest/reference/google_auth_oauthlib.flow.html#google_auth_oauthlib.flow.InstalledAppFlow.run_console

.. versionadded:: 0.2.0
dialect : str, default 'standard'
@@ -745,13 +771,6 @@
<https://cloud.google.com/bigquery/docs/access-control#roles>`__
permission on the project you are billing queries to.

**Note:** Due to a `known issue in the ``google-cloud-bigquery``
package
<https://github.com/googleapis/google-cloud-python/pull/7633>`__
(fixed in version 1.11.0), you must write your query results to a
destination table. To do this with ``read_gbq``, supply a
``configuration`` dictionary.

This feature requires the ``google-cloud-bigquery-storage`` and
``pyarrow`` packages.

@@ -823,6 +842,15 @@ def read_gbq(
if dialect not in ("legacy", "standard"):
raise ValueError("'{0}' is not valid for dialect".format(dialect))

if configuration and "query" in configuration and "query" in configuration["query"]:
if query_or_table is not None:
raise ValueError(
"Query statement can't be specified "
"inside config while it is specified "
"as parameter"
)
query_or_table = configuration["query"].pop("query")
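
Moving the check here keeps the configuration-based calling convention intact; a hypothetical call (illustrative project and query) looks like:

config = {"query": {"query": "SELECT 1", "useLegacySql": False}}
# query_or_table must be None when the SQL lives in the configuration dict
df = pandas_gbq.read_gbq(None, project_id="my-project", configuration=config)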

connector = GbqConnector(
project_id,
reauth=reauth,
@@ -834,13 +862,21 @@
use_bqstorage_api=use_bqstorage_api,
)

final_df = connector.run_query(
query,
configuration=configuration,
max_results=max_results,
progress_bar_type=progress_bar_type,
dtypes=dtypes,
)
if _is_query(query_or_table):
final_df = connector.run_query(
query_or_table,
configuration=configuration,
max_results=max_results,
progress_bar_type=progress_bar_type,
dtypes=dtypes,
)
else:
final_df = connector.download_table(
query_or_table,
max_results=max_results,
progress_bar_type=progress_bar_type,
dtypes=dtypes,
)
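
Taken together, the dispatch means both hypothetical calls below go down the right path (table and project names are made up):

import pandas_gbq

df1 = pandas_gbq.read_gbq("SELECT name FROM `my-project.dataset.table`")  # query path
df2 = pandas_gbq.read_gbq("my-project.dataset.table")  # direct row listing, no query job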

# Reindex the DataFrame on the provided column
if index_col is not None:
@@ -889,7 +925,7 @@ def to_gbq(

This method uses the Google Cloud client library to make requests to
Google BigQuery, documented `here
<https://google-cloud-python.readthedocs.io/en/latest/bigquery/usage.html>`__.
<https://googleapis.dev/python/bigquery/latest/index.html>`__.

See the :ref:`How to authenticate with Google BigQuery <authentication>`
guide for authentication instructions.
@@ -902,7 +938,7 @@
Name of table to be written, in the form ``dataset.tablename`` or
``project.dataset.tablename``.
project_id : str, optional
Google BigQuery Account project ID. Optional when available from
Google Cloud Platform project ID. Optional when available from
the environment.
chunksize : int, optional
Number of rows to be inserted in each chunk from the dataframe.
8 changes: 7 additions & 1 deletion pandas_gbq/timestamp.py
@@ -7,6 +7,8 @@
Private module.
"""

import pandas.api.types


def localize_df(df, schema_fields):
"""Localize any TIMESTAMP columns to tz-aware type.
@@ -38,7 +40,11 @@ def localize_df(df, schema_fields):
if "mode" in field and field["mode"].upper() == "REPEATED":
continue

if field["type"].upper() == "TIMESTAMP" and df[column].dt.tz is None:
if (
field["type"].upper() == "TIMESTAMP"
and pandas.api.types.is_datetime64_ns_dtype(df.dtypes[column])
and df[column].dt.tz is None
):
df[column] = df[column].dt.tz_localize("UTC")

return df
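
The added is_datetime64_ns_dtype guard matters because a TIMESTAMP column is not guaranteed to arrive as datetime64[ns] (an all-NULL column, for instance, may come back as object dtype), and calling .dt on such a series raises. A standalone sketch of the check:

import pandas
import pandas.api.types

s = pandas.Series([None, None], dtype="object")  # e.g. an all-NULL TIMESTAMP column
assert not pandas.api.types.is_datetime64_ns_dtype(s.dtype)  # guard fails, tz_localize skipped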
19 changes: 19 additions & 0 deletions tests/system/conftest.py
@@ -3,6 +3,7 @@
# license that can be found in the LICENSE file.

import os
import functools
import pathlib

from google.cloud import bigquery
@@ -56,6 +57,24 @@ def project(project_id):
return project_id


@pytest.fixture
def to_gbq(credentials, project_id):
import pandas_gbq

return functools.partial(
pandas_gbq.to_gbq, project_id=project_id, credentials=credentials
)


@pytest.fixture
def read_gbq(credentials, project_id):
import pandas_gbq

return functools.partial(
pandas_gbq.read_gbq, project_id=project_id, credentials=credentials
)
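
The partials pre-bind project_id and credentials so individual tests stay terse. A hypothetical test consuming the fixture:

def test_read_scalar(read_gbq):
    # project_id and credentials are already bound by the fixture
    df = read_gbq("SELECT 1 AS int_col")
    assert df["int_col"][0] == 1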


@pytest.fixture()
def random_dataset_id(bigquery_client: bigquery.Client, project_id: str):
dataset_id = prefixer.create_prefix()
19 changes: 7 additions & 12 deletions tests/system/test_to_gbq.py
@@ -5,7 +5,6 @@
import datetime
import decimal
import collections
import functools
import random

import db_dtypes
@@ -23,12 +22,8 @@ def api_method(request):


@pytest.fixture
def method_under_test(credentials, project_id):
import pandas_gbq

return functools.partial(
pandas_gbq.to_gbq, project_id=project_id, credentials=credentials
)
def method_under_test(to_gbq):
return to_gbq


SeriesRoundTripTestCase = collections.namedtuple(
@@ -98,7 +93,7 @@ def method_under_test(
def test_series_round_trip(
method_under_test,
random_dataset_id,
bigquery_client,
read_gbq,
input_series,
api_method,
api_methods,
@@ -114,7 +109,7 @@
)
method_under_test(df, table_id, api_method=api_method)

round_trip = bigquery_client.list_rows(table_id).to_dataframe()
round_trip = read_gbq(table_id)
round_trip_series = round_trip["test_col"].sort_values().reset_index(drop=True)
pandas.testing.assert_series_equal(
round_trip_series, input_series, check_exact=True, check_names=False,
@@ -244,8 +239,8 @@ def test_series_round_trip(
)
def test_dataframe_round_trip_with_table_schema(
method_under_test,
read_gbq,
random_dataset_id,
bigquery_client,
input_df,
expected_df,
table_schema,
@@ -260,8 +255,8 @@
method_under_test(
input_df, table_id, table_schema=table_schema, api_method=api_method
)
round_trip = bigquery_client.list_rows(table_id).to_dataframe(
dtypes=dict(zip(expected_df.columns, expected_df.dtypes))
round_trip = read_gbq(
table_id, dtypes=dict(zip(expected_df.columns, expected_df.dtypes)),
)
round_trip.sort_values("row_num", inplace=True)
pandas.testing.assert_frame_equal(expected_df, round_trip)
23 changes: 20 additions & 3 deletions tests/unit/conftest.py
@@ -26,18 +26,35 @@ def mock_bigquery_client(monkeypatch):
# Constructor returns the mock itself, so this mock can be treated as the
# constructor or the instance.
mock_client.return_value = mock_client
mock_schema = [google.cloud.bigquery.SchemaField("_f0", "INTEGER")]
# Mock out SELECT 1 query results.

mock_query = mock.create_autospec(google.cloud.bigquery.QueryJob)
mock_query.job_id = "some-random-id"
mock_query.state = "DONE"
mock_rows = mock.create_autospec(google.cloud.bigquery.table.RowIterator)
mock_rows.total_rows = 1
mock_rows.schema = mock_schema

mock_rows.__iter__.return_value = [(1,)]
mock_query.result.return_value = mock_rows
mock_client.list_rows.return_value = mock_rows
mock_client.query.return_value = mock_query
# Mock table creation.
monkeypatch.setattr(google.cloud.bigquery, "Client", mock_client)
mock_client.reset_mock()

# Mock out SELECT 1 query results.
def generate_schema():
query = mock_client.query.call_args[0][0] if mock_client.query.call_args else ""
if query == "SELECT 1 AS int_col":
return [google.cloud.bigquery.SchemaField("int_col", "INTEGER")]
else:
return [google.cloud.bigquery.SchemaField("_f0", "INTEGER")]

type(mock_rows).schema = mock.PropertyMock(side_effect=generate_schema)

# Mock out get_table.
def get_table(table_ref_or_id, **kwargs):
return google.cloud.bigquery.Table(table_ref_or_id)

mock_client.get_table.side_effect = get_table

return mock_client
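
The PropertyMock with a side_effect is what makes the mocked schema dynamic: the callable is re-evaluated on every attribute access, so mock_rows.schema tracks whichever query was issued most recently. A minimal standalone sketch of the pattern:

from unittest import mock

m = mock.MagicMock()
type(m).value = mock.PropertyMock(side_effect=lambda: 42)  # recomputed per access
assert m.value == 42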