Skip to content

Commit

Permalink
Raise with extra or missing columns in load_table_from_dataframe sc…
Browse files Browse the repository at this point in the history
…hema. (googleapis#9096)

I found it to be difficult to debug typos in column/index names in the
schema, so I have hardened the error messages to indicate when unknown
field values are found.
  • Loading branch information
tswast authored and plamut committed Aug 26, 2019
1 parent 964e46b commit ac1beab
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 27 deletions.
28 changes: 26 additions & 2 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,15 +211,18 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
"https://github.com/googleapis/google-cloud-python/issues/8191"
)
bq_schema_index = {field.name: field for field in bq_schema}
bq_schema_unused = set(bq_schema_index.keys())
else:
bq_schema_index = {}
bq_schema_unused = set()

bq_schema_out = []
for column, dtype in zip(dataframe.columns, dataframe.dtypes):
# Use provided type from schema, if present.
bq_field = bq_schema_index.get(column)
if bq_field:
bq_schema_out.append(bq_field)
bq_schema_unused.discard(bq_field.name)
continue

# Otherwise, try to automatically determine the type based on the
Expand All @@ -230,6 +233,15 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
return None
bq_field = schema.SchemaField(column, bq_type)
bq_schema_out.append(bq_field)

# Catch any schema mismatch. The developer explicitly asked to serialize a
# column, but it was not found.
if bq_schema_unused:
raise ValueError(
"bq_schema contains fields not present in dataframe: {}".format(
bq_schema_unused
)
)
return tuple(bq_schema_out)


Expand All @@ -248,9 +260,21 @@ def dataframe_to_arrow(dataframe, bq_schema):
Table containing dataframe data, with schema derived from
BigQuery schema.
"""
if len(bq_schema) != len(dataframe.columns):
column_names = set(dataframe.columns)
bq_field_names = set(field.name for field in bq_schema)

extra_fields = bq_field_names - column_names
if extra_fields:
raise ValueError(
"bq_schema contains fields not present in dataframe: {}".format(
extra_fields
)
)

missing_fields = column_names - bq_field_names
if missing_fields:
raise ValueError(
"Number of columns in schema must match number of columns in dataframe."
"bq_schema is missing fields from dataframe: {}".format(missing_fields)
)

arrow_arrays = []
Expand Down
20 changes: 17 additions & 3 deletions bigquery/tests/unit/test__pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,12 +609,26 @@ def test_dataframe_to_parquet_without_pyarrow(module_under_test, monkeypatch):

@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_parquet_w_missing_columns(module_under_test, monkeypatch):
def test_dataframe_to_parquet_w_extra_fields(module_under_test, monkeypatch):
with pytest.raises(ValueError) as exc_context:
module_under_test.dataframe_to_parquet(
pandas.DataFrame(), (schema.SchemaField("not_found", "STRING"),), None
pandas.DataFrame(), (schema.SchemaField("not_in_df", "STRING"),), None
)
assert "columns in schema must match" in str(exc_context.value)
message = str(exc_context.value)
assert "bq_schema contains fields not present in dataframe" in message
assert "not_in_df" in message


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_parquet_w_missing_fields(module_under_test, monkeypatch):
with pytest.raises(ValueError) as exc_context:
module_under_test.dataframe_to_parquet(
pandas.DataFrame({"not_in_bq": [1, 2, 3]}), (), None
)
message = str(exc_context.value)
assert "bq_schema is missing fields from dataframe" in message
assert "not_in_bq" in message


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
Expand Down
29 changes: 7 additions & 22 deletions bigquery/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5517,7 +5517,6 @@ def test_load_table_from_dataframe_w_partial_schema(self):
@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_partial_schema_extra_types(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

Expand All @@ -5540,31 +5539,17 @@ def test_load_table_from_dataframe_w_partial_schema_extra_types(self):
SchemaField("unknown_col", "BYTES"),
)
job_config = job.LoadJobConfig(schema=schema)
with load_patch as load_table_from_file:
with load_patch as load_table_from_file, pytest.raises(
ValueError
) as exc_context:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION
)

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
self.TABLE_REF,
num_retries=_DEFAULT_NUM_RETRIES,
rewind=True,
job_id=mock.ANY,
job_id_prefix=None,
location=self.LOCATION,
project=None,
job_config=mock.ANY,
)

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config.source_format == job.SourceFormat.PARQUET
assert tuple(sent_config.schema) == (
SchemaField("int_col", "INTEGER"),
SchemaField("int_as_float_col", "INTEGER"),
SchemaField("string_col", "STRING"),
)
load_table_from_file.assert_not_called()
message = str(exc_context.value)
assert "bq_schema contains fields not present in dataframe" in message
assert "unknown_col" in message

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
Expand Down

0 comments on commit ac1beab

Please sign in to comment.