Skip to content

Commit

Permalink
Determine the schema in load_table_from_dataframe based on dtypes. (#…
Browse files Browse the repository at this point in the history
…9049)

* Determine the schema in `load_table_from_dataframe` based on dtypes.

This PR updates `load_table_from_dataframe` to automatically determine
the BigQuery schema based on the DataFrame's dtypes. If any field's type
cannot be determined, fallback to the logic in the pandas `to_parquet`
method.

* Fix test coverage.

* Reduce duplication by using OrderedDict

* Add columns option to DataFrame constructor to ensure correct column order.
  • Loading branch information
tswast authored Aug 21, 2019
1 parent c927a72 commit fcf99ce
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 2 deletions.
40 changes: 40 additions & 0 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@

_PROGRESS_INTERVAL = 0.2 # Maximum time between download status checks, in seconds.

_PANDAS_DTYPE_TO_BQ = {
"bool": "BOOLEAN",
"datetime64[ns, UTC]": "TIMESTAMP",
"datetime64[ns]": "DATETIME",
"float32": "FLOAT",
"float64": "FLOAT",
"int8": "INTEGER",
"int16": "INTEGER",
"int32": "INTEGER",
"int64": "INTEGER",
"uint8": "INTEGER",
"uint16": "INTEGER",
"uint32": "INTEGER",
}


class _DownloadState(object):
"""Flag to indicate that a thread should exit early."""
Expand Down Expand Up @@ -172,6 +187,31 @@ def bq_to_arrow_array(series, bq_field):
return pyarrow.array(series, type=arrow_type)


def dataframe_to_bq_schema(dataframe):
"""Convert a pandas DataFrame schema to a BigQuery schema.
TODO(GH#8140): Add bq_schema argument to allow overriding autodetected
schema for a subset of columns.
Args:
dataframe (pandas.DataFrame):
DataFrame to convert to convert to Parquet file.
Returns:
Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]:
The automatically determined schema. Returns None if the type of
any column cannot be determined.
"""
bq_schema = []
for column, dtype in zip(dataframe.columns, dataframe.dtypes):
bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
if not bq_type:
return None
bq_field = schema.SchemaField(column, bq_type)
bq_schema.append(bq_field)
return tuple(bq_schema)


def dataframe_to_arrow(dataframe, bq_schema):
"""Convert pandas dataframe to Arrow table, using BigQuery schema.
Expand Down
15 changes: 15 additions & 0 deletions bigquery/google/cloud/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
except ImportError: # Python 2.7
import collections as collections_abc

import copy
import functools
import gzip
import io
Expand Down Expand Up @@ -1521,11 +1522,25 @@ def load_table_from_dataframe(

if job_config is None:
job_config = job.LoadJobConfig()
else:
# Make a copy so that the job config isn't modified in-place.
job_config_properties = copy.deepcopy(job_config._properties)
job_config = job.LoadJobConfig()
job_config._properties = job_config_properties
job_config.source_format = job.SourceFormat.PARQUET

if location is None:
location = self.location

if not job_config.schema:
autodetected_schema = _pandas_helpers.dataframe_to_bq_schema(dataframe)

# Only use an explicit schema if we were able to determine one
# matching the dataframe. If not, fallback to the pandas to_parquet
# method.
if autodetected_schema:
job_config.schema = autodetected_schema

tmpfd, tmppath = tempfile.mkstemp(suffix="_job_{}.parquet".format(job_id[:8]))
os.close(tmpfd)

Expand Down
76 changes: 76 additions & 0 deletions bigquery/tests/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import base64
import collections
import concurrent.futures
import csv
import datetime
Expand Down Expand Up @@ -634,6 +635,81 @@ def test_load_table_from_local_avro_file_then_dump_table(self):
sorted(row_tuples, key=by_wavelength), sorted(ROWS, key=by_wavelength)
)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_automatic_schema(self):
"""Test that a DataFrame with dtypes that map well to BigQuery types
can be uploaded without specifying a schema.
https://github.com/googleapis/google-cloud-python/issues/9044
"""
df_data = collections.OrderedDict(
[
("bool_col", pandas.Series([True, False, True], dtype="bool")),
(
"ts_col",
pandas.Series(
[
datetime.datetime(2010, 1, 2, 3, 44, 50),
datetime.datetime(2011, 2, 3, 14, 50, 59),
datetime.datetime(2012, 3, 14, 15, 16),
],
dtype="datetime64[ns]",
).dt.tz_localize(pytz.utc),
),
(
"dt_col",
pandas.Series(
[
datetime.datetime(2010, 1, 2, 3, 44, 50),
datetime.datetime(2011, 2, 3, 14, 50, 59),
datetime.datetime(2012, 3, 14, 15, 16),
],
dtype="datetime64[ns]",
),
),
("float32_col", pandas.Series([1.0, 2.0, 3.0], dtype="float32")),
("float64_col", pandas.Series([4.0, 5.0, 6.0], dtype="float64")),
("int8_col", pandas.Series([-12, -11, -10], dtype="int8")),
("int16_col", pandas.Series([-9, -8, -7], dtype="int16")),
("int32_col", pandas.Series([-6, -5, -4], dtype="int32")),
("int64_col", pandas.Series([-3, -2, -1], dtype="int64")),
("uint8_col", pandas.Series([0, 1, 2], dtype="uint8")),
("uint16_col", pandas.Series([3, 4, 5], dtype="uint16")),
("uint32_col", pandas.Series([6, 7, 8], dtype="uint32")),
]
)
dataframe = pandas.DataFrame(df_data, columns=df_data.keys())

dataset_id = _make_dataset_id("bq_load_test")
self.temp_dataset(dataset_id)
table_id = "{}.{}.load_table_from_dataframe_w_automatic_schema".format(
Config.CLIENT.project, dataset_id
)

load_job = Config.CLIENT.load_table_from_dataframe(dataframe, table_id)
load_job.result()

table = Config.CLIENT.get_table(table_id)
self.assertEqual(
tuple(table.schema),
(
bigquery.SchemaField("bool_col", "BOOLEAN"),
bigquery.SchemaField("ts_col", "TIMESTAMP"),
bigquery.SchemaField("dt_col", "DATETIME"),
bigquery.SchemaField("float32_col", "FLOAT"),
bigquery.SchemaField("float64_col", "FLOAT"),
bigquery.SchemaField("int8_col", "INTEGER"),
bigquery.SchemaField("int16_col", "INTEGER"),
bigquery.SchemaField("int32_col", "INTEGER"),
bigquery.SchemaField("int64_col", "INTEGER"),
bigquery.SchemaField("uint8_col", "INTEGER"),
bigquery.SchemaField("uint16_col", "INTEGER"),
bigquery.SchemaField("uint32_col", "INTEGER"),
),
)
self.assertEqual(table.num_rows, 3)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_nulls(self):
Expand Down
74 changes: 72 additions & 2 deletions bigquery/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import copy
import collections
import datetime
import decimal
import email
Expand Down Expand Up @@ -5325,9 +5326,78 @@ def test_load_table_from_dataframe_w_custom_job_config(self):
)

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config is job_config
assert sent_config.source_format == job.SourceFormat.PARQUET

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_automatic_schema(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client()
df_data = collections.OrderedDict(
[
("int_col", [1, 2, 3]),
("float_col", [1.0, 2.0, 3.0]),
("bool_col", [True, False, True]),
(
"dt_col",
pandas.Series(
[
datetime.datetime(2010, 1, 2, 3, 44, 50),
datetime.datetime(2011, 2, 3, 14, 50, 59),
datetime.datetime(2012, 3, 14, 15, 16),
],
dtype="datetime64[ns]",
),
),
(
"ts_col",
pandas.Series(
[
datetime.datetime(2010, 1, 2, 3, 44, 50),
datetime.datetime(2011, 2, 3, 14, 50, 59),
datetime.datetime(2012, 3, 14, 15, 16),
],
dtype="datetime64[ns]",
).dt.tz_localize(pytz.utc),
),
]
)
dataframe = pandas.DataFrame(df_data, columns=df_data.keys())
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)

with load_patch as load_table_from_file:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, location=self.LOCATION
)

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
self.TABLE_REF,
num_retries=_DEFAULT_NUM_RETRIES,
rewind=True,
job_id=mock.ANY,
job_id_prefix=None,
location=self.LOCATION,
project=None,
job_config=mock.ANY,
)

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config.source_format == job.SourceFormat.PARQUET
assert tuple(sent_config.schema) == (
SchemaField("int_col", "INTEGER"),
SchemaField("float_col", "FLOAT"),
SchemaField("bool_col", "BOOLEAN"),
SchemaField("dt_col", "DATETIME"),
SchemaField("ts_col", "TIMESTAMP"),
)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_struct_fields_error(self):
Expand Down Expand Up @@ -5509,7 +5579,7 @@ def test_load_table_from_dataframe_w_nulls(self):
)

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config is job_config
assert sent_config.schema == schema
assert sent_config.source_format == job.SourceFormat.PARQUET

# Low-level tests
Expand Down

0 comments on commit fcf99ce

Please sign in to comment.