
Improve and refactor pyarrow schema detection
Add more pyarrow types, and convert to pyarrow only the columns for which
the schema could not be detected, etc.
plamut committed Oct 19, 2019
1 parent dd43f6b commit a07397a
Showing 2 changed files with 218 additions and 51 deletions.
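
For context, the new mapping keys on pyarrow type IDs rather than on type
instances. A minimal sketch of the idea (assumes pyarrow is installed; this
snippet is illustrative and not part of the diff):

import pyarrow

# A type ID identifies the type class, not its parametrization: all
# decimal128 instances share one ID regardless of precision and scale,
# and all timestamps share one ID regardless of resolution.
assert pyarrow.decimal128(38, 9).id == pyarrow.decimal128(10, 2).id
assert pyarrow.timestamp("ns").id == pyarrow.timestamp("s").id

# Instance equality, by contrast, does compare the parameters.
assert pyarrow.decimal128(38, 9) != pyarrow.decimal128(10, 2)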
124 changes: 88 additions & 36 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
@@ -110,13 +110,35 @@ def pyarrow_timestamp():
"TIME": pyarrow_time,
"TIMESTAMP": pyarrow_timestamp,
}
ARROW_SCALARS_TO_BQ = {
arrow_type(): bq_type # TODO: explain wht calling arrow_type()
for bq_type, arrow_type in BQ_TO_ARROW_SCALARS.items()
ARROW_SCALAR_IDS_TO_BQ = {
# https://arrow.apache.org/docs/python/api/datatypes.html#type-classes
pyarrow.bool_().id: "BOOL",
pyarrow.int8().id: "INT64",
pyarrow.int16().id: "INT64",
pyarrow.int32().id: "INT64",
pyarrow.int64().id: "INT64",
pyarrow.uint8().id: "INT64",
pyarrow.uint16().id: "INT64",
pyarrow.uint32().id: "INT64",
pyarrow.uint64().id: "INT64",
pyarrow.float16().id: "FLOAT64",
pyarrow.float32().id: "FLOAT64",
pyarrow.float64().id: "FLOAT64",
pyarrow.time32("ms").id: "TIME",
pyarrow.time64("ns").id: "TIME",
pyarrow.timestamp("ns").id: "TIMESTAMP",
pyarrow.date32().id: "DATE",
pyarrow.date64().id: "DATETIME", # because millisecond resolution
pyarrow.binary().id: "BYTES",
pyarrow.string().id: "STRING", # also alias for pyarrow.utf8()
pyarrow.decimal128(38, scale=9).id: "NUMERIC",
# The exact decimal's scale and precision are not important, as only
# the type ID matters, and it's the same for all decimal128 instances.
}

else: # pragma: NO COVER
BQ_TO_ARROW_SCALARS = {} # pragma: NO COVER
ARROW_SCALARS_TO_BQ = {} # pragma: NO_COVER
ARROW_SCALAR_IDS_TO_BQ = {} # pragma: NO_COVER


def bq_to_arrow_struct_data_type(field):
@@ -269,6 +291,8 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
bq_schema_unused = set()

bq_schema_out = []
unknown_type_fields = []

for column, dtype in list_columns_and_indexes(dataframe):
# Use provided type from schema, if present.
bq_field = bq_schema_index.get(column)
@@ -280,12 +304,12 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
# Otherwise, try to automatically determine the type based on the
# pandas dtype.
bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
if not bq_type:
warnings.warn(u"Unable to determine type of column '{}'.".format(column))

bq_field = schema.SchemaField(column, bq_type)
bq_schema_out.append(bq_field)

if bq_field.field_type is None:
unknown_type_fields.append(bq_field)

# Catch any schema mismatch. The developer explicitly asked to serialize a
# column, but it was not found.
if bq_schema_unused:
@@ -297,42 +321,70 @@ def dataframe_to_bq_schema(dataframe, bq_schema):

# If schema detection was not successful for all columns, also try with
# pyarrow, if available.
if any(field.field_type is None for field in bq_schema_out):
if unknown_type_fields:
if not pyarrow:
msg = u"Could not determine the type of columns: {}".format(
", ".join(field.name for field in unknown_type_fields)
)
warnings.warn(msg)
return None # We cannot detect the schema in full.

arrow_table = dataframe_to_arrow(dataframe, bq_schema_out)
arrow_schema_index = {field.name: field.type for field in arrow_table}
# The currate_schema() helper itself will also issue unknown type
# warnings if detection still fails for any of the fields.
bq_schema_out = currate_schema(dataframe, bq_schema_out)

currated_schema = []
for schema_field in bq_schema_out:
if schema_field.field_type is not None:
currated_schema.append(schema_field)
continue
return tuple(bq_schema_out) if bq_schema_out else None

detected_type = ARROW_SCALARS_TO_BQ.get(
arrow_schema_index.get(schema_field.name)
)
if detected_type is None:
warnings.warn(
u"Pyarrow could not determine the type of column '{}'.".format(
schema_field.name
)
)
return None

new_field = schema.SchemaField(
name=schema_field.name,
field_type=detected_type,
mode=schema_field.mode,
description=schema_field.description,
fields=schema_field.fields,
)
currated_schema.append(new_field)

bq_schema_out = currated_schema
def currate_schema(dataframe, current_bq_schema):
"""Try to deduce the unknown field types and return an improved schema.
This function requires ``pyarrow`` to run. If all the missing types still
cannot be detected, ``None`` is returned. If all types are already known,
a shallow copy of the given schema is returned.
Args:
dataframe (pandas.DataFrame):
DataFrame for which some of the field types are still unknown.
current_bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
A BigQuery schema for ``dataframe``. The types of some or all of
the fields may be ``None``.
Returns:
Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]
"""
currated_schema = []
unknown_type_fields = []

for field in current_bq_schema:
if field.field_type is not None:
currated_schema.append(field)
continue

arrow_table = pyarrow.array(dataframe[field.name])
detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id)

if detected_type is None:
unknown_type_fields.append(field)
continue

new_field = schema.SchemaField(
name=field.name,
field_type=detected_type,
mode=field.mode,
description=field.description,
fields=field.fields,
)
currated_schema.append(new_field)

if unknown_type_fields:
warnings.warn(
u"Pyarrow could not determine the type of columns: {}.".format(
", ".join(field.name for field in unknown_type_fields)
)
)
return None

return tuple(bq_schema_out)
return currated_schema


def dataframe_to_arrow(dataframe, bq_schema):
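To see the refactored fallback end to end, here is a hypothetical usage
sketch (assumes pandas and pyarrow are installed; it mirrors the behavior
exercised by the tests below, not an API documented by this commit):

import datetime

import pandas

from google.cloud.bigquery import _pandas_helpers

df = pandas.DataFrame({"created_at": [datetime.date(2019, 5, 10)]})

# The date column's "object" dtype defeats the pandas dtype lookup, so only
# this column is converted to pyarrow and matched by its type ID.
detected = _pandas_helpers.dataframe_to_bq_schema(df, bq_schema=[])
# With pyarrow: (SchemaField("created_at", "DATE", mode="NULLABLE"),)
# Without pyarrow: None, plus a "Could not determine the type" warning.
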
145 changes: 130 additions & 15 deletions bigquery/tests/unit/test__pandas_helpers.py
@@ -16,6 +16,7 @@
import datetime
import decimal
import functools
import operator
import warnings

import mock
@@ -911,47 +912,66 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"id": 10, "status": "FOO", "execution_date": datetime.date(2019, 5, 10)},
{"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
{"id": 10, "status": u"FOO", "execution_date": datetime.date(2019, 5, 10)},
{"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
]
)

no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)

with no_pyarrow_patch:
with no_pyarrow_patch, warnings.catch_warnings(record=True) as warned:
detected_schema = module_under_test.dataframe_to_bq_schema(
dataframe, bq_schema=[]
)

assert detected_schema is None

# a warning should also be issued
expected_warnings = [
warning for warning in warned if "could not determine" in str(warning).lower()
]
assert len(expected_warnings) == 1
msg = str(expected_warnings[0])
assert "execution_date" in msg and "created_at" in msg


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"id": 10, "status": "FOO", "created_at": datetime.date(2019, 5, 10)},
{"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
{"id": 10, "status": u"FOO", "created_at": datetime.date(2019, 5, 10)},
{"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
]
)

detected_schema = module_under_test.dataframe_to_bq_schema(dataframe, bq_schema=[])
with warnings.catch_warnings(record=True) as warned:
detected_schema = module_under_test.dataframe_to_bq_schema(
dataframe, bq_schema=[]
)

expected_schema = (
schema.SchemaField("id", "INTEGER", mode="NULLABLE"),
schema.SchemaField("status", "STRING", mode="NULLABLE"),
schema.SchemaField("created_at", "DATE", mode="NULLABLE"),
)
assert detected_schema == expected_schema
by_name = operator.attrgetter("name")
assert sorted(detected_schema, key=by_name) == sorted(expected_schema, key=by_name)

# there should be no relevant warnings
unwanted_warnings = [
warning for warning in warned if "could not determine" in str(warning).lower()
]
assert not unwanted_warnings


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"id": 10, "status": "FOO", "all_items": [10.1, 10.2]},
{"id": 20, "status": "BAR", "all_items": [20.1, 20.2]},
{"struct_field": {"one": 2}, "status": u"FOO"},
{"struct_field": {"two": u"222"}, "status": u"BAR"},
]
)

@@ -962,12 +982,107 @@ def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):

assert detected_schema is None

expected_warnings = []
for warning in warned:
if "Pyarrow could not" in str(warning):
expected_warnings.append(warning)
# a warning should also be issued
expected_warnings = [
warning for warning in warned if "could not determine" in str(warning).lower()
]
assert len(expected_warnings) == 1
assert "struct_field" in str(expected_warnings[0])


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_currate_schema_type_detection_succeeds(module_under_test):
dataframe = pandas.DataFrame(
data=[
{
"bool_field": False,
"int_field": 123,
"float_field": 3.141592,
"time_field": datetime.time(17, 59, 47),
"timestamp_field": datetime.datetime(2005, 5, 31, 14, 25, 55),
"date_field": datetime.date(2005, 5, 31),
"bytes_field": b"some bytes",
"string_field": u"some characters",
"numeric_field": decimal.Decimal("123.456"),
}
]
)

# NOTE: In Pandas dataframe, the dtype of Python's datetime instances is
# set to "datetime64[ns]", and pyarrow converts that to pyarrow.TimestampArray.
# We thus cannot expect to get a DATETIME type when converting back to the
# BigQuery type.

current_schema = (
schema.SchemaField("bool_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("int_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("float_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("time_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("timestamp_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("date_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("bytes_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("string_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("numeric_field", field_type=None, mode="NULLABLE"),
)

with warnings.catch_warnings(record=True) as warned:
currated_schema = module_under_test.currate_schema(dataframe, current_schema)

# there should be no relevant warnings
unwanted_warnings = [
warning for warning in warned if "Pyarrow could not" in str(warning)
]
assert not unwanted_warnings

# the currated schema must match the expected
expected_schema = (
schema.SchemaField("bool_field", field_type="BOOL", mode="NULLABLE"),
schema.SchemaField("int_field", field_type="INT64", mode="NULLABLE"),
schema.SchemaField("float_field", field_type="FLOAT64", mode="NULLABLE"),
schema.SchemaField("time_field", field_type="TIME", mode="NULLABLE"),
schema.SchemaField("timestamp_field", field_type="TIMESTAMP", mode="NULLABLE"),
schema.SchemaField("date_field", field_type="DATE", mode="NULLABLE"),
schema.SchemaField("bytes_field", field_type="BYTES", mode="NULLABLE"),
schema.SchemaField("string_field", field_type="STRING", mode="NULLABLE"),
schema.SchemaField("numeric_field", field_type="NUMERIC", mode="NULLABLE"),
)
by_name = operator.attrgetter("name")
assert sorted(currated_schema, key=by_name) == sorted(expected_schema, key=by_name)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_currate_schema_type_detection_fails(module_under_test):
dataframe = pandas.DataFrame(
data=[
{
"status": u"FOO",
"struct_field": {"one": 1},
"struct_field_2": {"foo": u"123"},
},
{
"status": u"BAR",
"struct_field": {"two": u"111"},
"struct_field_2": {"bar": 27},
},
]
)
current_schema = [
schema.SchemaField("status", field_type="STRING", mode="NULLABLE"),
schema.SchemaField("struct_field", field_type=None, mode="NULLABLE"),
schema.SchemaField("struct_field_2", field_type=None, mode="NULLABLE"),
]

with warnings.catch_warnings(record=True) as warned:
currated_schema = module_under_test.currate_schema(dataframe, current_schema)

assert currated_schema is None

expected_warnings = [
warning for warning in warned if "could not determine" in str(warning)
]
assert len(expected_warnings) == 1
warning_msg = str(expected_warnings[0])
assert "all_items" in warning_msg
assert "could not determine the type" in warning_msg
assert "pyarrow" in warning_msg.lower()
assert "struct_field" in warning_msg and "struct_field_2" in warning_msg

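The failure-path tests hinge on struct columns: pyarrow infers a struct
type whose ID has no entry in ARROW_SCALAR_IDS_TO_BQ. A minimal sketch of
both that and the datetime behavior called out in the NOTE above (assumes
pandas and pyarrow are installed):

import datetime

import pandas
import pyarrow

# Dict values are inferred as a pyarrow struct; struct IDs are absent from
# ARROW_SCALAR_IDS_TO_BQ, so currate_schema records the field as unknown,
# warns, and returns None.
arr = pyarrow.array([{"one": 1}, {"two": 2}])
print(arr.type)  # struct<one: int64, two: int64>

# Python datetimes arrive in pandas as datetime64[ns] and convert to a
# timestamp array, which is why the succeeding test expects TIMESTAMP
# rather than DATETIME for "timestamp_field".
series = pandas.Series([datetime.datetime(2005, 5, 31, 14, 25, 55)])
print(pyarrow.array(series).type)  # timestamp[ns]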