From a07397a8e53161fab6df61a7654dc369603c2d86 Mon Sep 17 00:00:00 2001 From: Peter Lamut Date: Tue, 1 Oct 2019 19:51:19 +0200 Subject: [PATCH] Improve and refactor pyarrow schema detection Add more pyarrow types, convert to pyarrow only the columns the schema could not be detected for, etc. --- .../google/cloud/bigquery/_pandas_helpers.py | 124 ++++++++++----- bigquery/tests/unit/test__pandas_helpers.py | 145 ++++++++++++++++-- 2 files changed, 218 insertions(+), 51 deletions(-) diff --git a/bigquery/google/cloud/bigquery/_pandas_helpers.py b/bigquery/google/cloud/bigquery/_pandas_helpers.py index 7bac15f78556..b98cb7d833d2 100644 --- a/bigquery/google/cloud/bigquery/_pandas_helpers.py +++ b/bigquery/google/cloud/bigquery/_pandas_helpers.py @@ -110,13 +110,35 @@ def pyarrow_timestamp(): "TIME": pyarrow_time, "TIMESTAMP": pyarrow_timestamp, } - ARROW_SCALARS_TO_BQ = { - arrow_type(): bq_type # TODO: explain wht calling arrow_type() - for bq_type, arrow_type in BQ_TO_ARROW_SCALARS.items() + ARROW_SCALAR_IDS_TO_BQ = { + # https://arrow.apache.org/docs/python/api/datatypes.html#type-classes + pyarrow.bool_().id: "BOOL", + pyarrow.int8().id: "INT64", + pyarrow.int16().id: "INT64", + pyarrow.int32().id: "INT64", + pyarrow.int64().id: "INT64", + pyarrow.uint8().id: "INT64", + pyarrow.uint16().id: "INT64", + pyarrow.uint32().id: "INT64", + pyarrow.uint64().id: "INT64", + pyarrow.float16().id: "FLOAT64", + pyarrow.float32().id: "FLOAT64", + pyarrow.float64().id: "FLOAT64", + pyarrow.time32("ms").id: "TIME", + pyarrow.time64("ns").id: "TIME", + pyarrow.timestamp("ns").id: "TIMESTAMP", + pyarrow.date32().id: "DATE", + pyarrow.date64().id: "DATETIME", # because millisecond resolution + pyarrow.binary().id: "BYTES", + pyarrow.string().id: "STRING", # also alias for pyarrow.utf8() + pyarrow.decimal128(38, scale=9).id: "NUMERIC", + # The exact decimal's scale and precision are not important, as only + # the type ID matters, and it's the same for all decimal128 instances. } + else: # pragma: NO COVER BQ_TO_ARROW_SCALARS = {} # pragma: NO COVER - ARROW_SCALARS_TO_BQ = {} # pragma: NO_COVER + ARROW_SCALAR_IDS_TO_BQ = {} # pragma: NO_COVER def bq_to_arrow_struct_data_type(field): @@ -269,6 +291,8 @@ def dataframe_to_bq_schema(dataframe, bq_schema): bq_schema_unused = set() bq_schema_out = [] + unknown_type_fields = [] + for column, dtype in list_columns_and_indexes(dataframe): # Use provided type from schema, if present. bq_field = bq_schema_index.get(column) @@ -280,12 +304,12 @@ def dataframe_to_bq_schema(dataframe, bq_schema): # Otherwise, try to automatically determine the type based on the # pandas dtype. bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name) - if not bq_type: - warnings.warn(u"Unable to determine type of column '{}'.".format(column)) - bq_field = schema.SchemaField(column, bq_type) bq_schema_out.append(bq_field) + if bq_field.field_type is None: + unknown_type_fields.append(bq_field) + # Catch any schema mismatch. The developer explicitly asked to serialize a # column, but it was not found. if bq_schema_unused: @@ -297,42 +321,70 @@ def dataframe_to_bq_schema(dataframe, bq_schema): # If schema detection was not successful for all columns, also try with # pyarrow, if available. - if any(field.field_type is None for field in bq_schema_out): + if unknown_type_fields: if not pyarrow: + msg = u"Could not determine the type of columns: {}".format( + ", ".join(field.name for field in unknown_type_fields) + ) + warnings.warn(msg) return None # We cannot detect the schema in full. 
- arrow_table = dataframe_to_arrow(dataframe, bq_schema_out)
- arrow_schema_index = {field.name: field.type for field in arrow_table}
+ # The currate_schema() helper itself will also issue unknown type
+ # warnings if detection still fails for any of the fields.
+ bq_schema_out = currate_schema(dataframe, bq_schema_out)

- currated_schema = []
- for schema_field in bq_schema_out:
- if schema_field.field_type is not None:
- currated_schema.append(schema_field)
- continue
+ return tuple(bq_schema_out) if bq_schema_out else None

- detected_type = ARROW_SCALARS_TO_BQ.get(
- arrow_schema_index.get(schema_field.name)
- )
- if detected_type is None:
- warnings.warn(
- u"Pyarrow could not determine the type of column '{}'.".format(
- schema_field.name
- )
- )
- return None
-
- new_field = schema.SchemaField(
- name=schema_field.name,
- field_type=detected_type,
- mode=schema_field.mode,
- description=schema_field.description,
- fields=schema_field.fields,
- )
- currated_schema.append(new_field)
- bq_schema_out = currated_schema

+def currate_schema(dataframe, current_bq_schema):
+ """Try to deduce the unknown field types and return an improved schema.
+
+ This function requires ``pyarrow`` to run. If any of the missing field
+ types still cannot be detected, ``None`` is returned. If all types are
+ already known, a shallow copy of the given schema is returned.
+
+ Args:
+ dataframe (pandas.DataFrame):
+ DataFrame for which some of the field types are still unknown.
+ current_bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
+ A BigQuery schema for ``dataframe``. The types of some or all of
+ the fields may be ``None``.
+ Returns:
+ Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]
+ """
+ currated_schema = []
+ unknown_type_fields = []
+
+ for field in current_bq_schema:
+ if field.field_type is not None:
+ currated_schema.append(field)
+ continue
+
+ arrow_array = pyarrow.array(dataframe[field.name])
+ detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_array.type.id)
+
+ if detected_type is None:
+ unknown_type_fields.append(field)
+ continue
+
+ new_field = schema.SchemaField(
+ name=field.name,
+ field_type=detected_type,
+ mode=field.mode,
+ description=field.description,
+ fields=field.fields,
+ )
+ currated_schema.append(new_field)
+
+ if unknown_type_fields:
+ warnings.warn(
+ u"Pyarrow could not determine the type of columns: {}.".format(
+ ", ".join(field.name for field in unknown_type_fields)
+ )
+ )
+ return None

- return tuple(bq_schema_out)
+ return currated_schema


def dataframe_to_arrow(dataframe, bq_schema):
diff --git a/bigquery/tests/unit/test__pandas_helpers.py b/bigquery/tests/unit/test__pandas_helpers.py
index 3ffbfc002a79..0f323b257b26 100644
--- a/bigquery/tests/unit/test__pandas_helpers.py
+++ b/bigquery/tests/unit/test__pandas_helpers.py
@@ -16,6 +16,7 @@
import datetime
import decimal
import functools
+import operator
import warnings

import mock
@@ -911,38 +912,57 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
dataframe = pandas.DataFrame(
data=[
- {"id": 10, "status": "FOO", "execution_date": datetime.date(2019, 5, 10)},
- {"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
+ {"id": 10, "status": u"FOO", "execution_date": datetime.date(2019, 5, 10)},
+ {"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
]
)

no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)

- with no_pyarrow_patch:
+ with 
no_pyarrow_patch, warnings.catch_warnings(record=True) as warned: detected_schema = module_under_test.dataframe_to_bq_schema( dataframe, bq_schema=[] ) assert detected_schema is None + # a warning should also be issued + expected_warnings = [ + warning for warning in warned if "could not determine" in str(warning).lower() + ] + assert len(expected_warnings) == 1 + msg = str(expected_warnings[0]) + assert "execution_date" in msg and "created_at" in msg + @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test): dataframe = pandas.DataFrame( data=[ - {"id": 10, "status": "FOO", "created_at": datetime.date(2019, 5, 10)}, - {"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)}, + {"id": 10, "status": u"FOO", "created_at": datetime.date(2019, 5, 10)}, + {"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)}, ] ) - detected_schema = module_under_test.dataframe_to_bq_schema(dataframe, bq_schema=[]) + with warnings.catch_warnings(record=True) as warned: + detected_schema = module_under_test.dataframe_to_bq_schema( + dataframe, bq_schema=[] + ) + expected_schema = ( schema.SchemaField("id", "INTEGER", mode="NULLABLE"), schema.SchemaField("status", "STRING", mode="NULLABLE"), schema.SchemaField("created_at", "DATE", mode="NULLABLE"), ) - assert detected_schema == expected_schema + by_name = operator.attrgetter("name") + assert sorted(detected_schema, key=by_name) == sorted(expected_schema, key=by_name) + + # there should be no relevant warnings + unwanted_warnings = [ + warning for warning in warned if "could not determine" in str(warning).lower() + ] + assert not unwanted_warnings @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @@ -950,8 +970,8 @@ def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test): def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test): dataframe = pandas.DataFrame( data=[ - {"id": 10, "status": "FOO", "all_items": [10.1, 10.2]}, - {"id": 20, "status": "BAR", "all_items": [20.1, 20.2]}, + {"struct_field": {"one": 2}, "status": u"FOO"}, + {"struct_field": {"two": u"222"}, "status": u"BAR"}, ] ) @@ -962,12 +982,107 @@ def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test): assert detected_schema is None - expected_warnings = [] - for warning in warned: - if "Pyarrow could not" in str(warning): - expected_warnings.append(warning) + # a warning should also be issued + expected_warnings = [ + warning for warning in warned if "could not determine" in str(warning).lower() + ] + assert len(expected_warnings) == 1 + assert "struct_field" in str(expected_warnings[0]) + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") +def test_currate_schema_type_detection_succeeds(module_under_test): + dataframe = pandas.DataFrame( + data=[ + { + "bool_field": False, + "int_field": 123, + "float_field": 3.141592, + "time_field": datetime.time(17, 59, 47), + "timestamp_field": datetime.datetime(2005, 5, 31, 14, 25, 55), + "date_field": datetime.date(2005, 5, 31), + "bytes_field": b"some bytes", + "string_field": u"some characters", + "numeric_field": decimal.Decimal("123.456"), + } + ] + ) + + # NOTE: In Pandas dataframe, the dtype of Python's datetime instances is + # set to "datetime64[ns]", and pyarrow converts that to pyarrow.TimestampArray. 
+ # We thus cannot expect to get a DATETIME field when converting back to
+ # a BigQuery type.
+
+ current_schema = (
+ schema.SchemaField("bool_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("int_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("float_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("time_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("timestamp_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("date_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("bytes_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("string_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("numeric_field", field_type=None, mode="NULLABLE"),
+ )
+
+ with warnings.catch_warnings(record=True) as warned:
+ currated_schema = module_under_test.currate_schema(dataframe, current_schema)
+ # there should be no relevant warnings
+ unwanted_warnings = [
+ warning for warning in warned if "Pyarrow could not" in str(warning)
+ ]
+ assert not unwanted_warnings
+
+ # the currated schema must match the expected schema
+ expected_schema = (
+ schema.SchemaField("bool_field", field_type="BOOL", mode="NULLABLE"),
+ schema.SchemaField("int_field", field_type="INT64", mode="NULLABLE"),
+ schema.SchemaField("float_field", field_type="FLOAT64", mode="NULLABLE"),
+ schema.SchemaField("time_field", field_type="TIME", mode="NULLABLE"),
+ schema.SchemaField("timestamp_field", field_type="TIMESTAMP", mode="NULLABLE"),
+ schema.SchemaField("date_field", field_type="DATE", mode="NULLABLE"),
+ schema.SchemaField("bytes_field", field_type="BYTES", mode="NULLABLE"),
+ schema.SchemaField("string_field", field_type="STRING", mode="NULLABLE"),
+ schema.SchemaField("numeric_field", field_type="NUMERIC", mode="NULLABLE"),
+ )
+ by_name = operator.attrgetter("name")
+ assert sorted(currated_schema, key=by_name) == sorted(expected_schema, key=by_name)
+
+
+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
+def test_currate_schema_type_detection_fails(module_under_test):
+ dataframe = pandas.DataFrame(
+ data=[
+ {
+ "status": u"FOO",
+ "struct_field": {"one": 1},
+ "struct_field_2": {"foo": u"123"},
+ },
+ {
+ "status": u"BAR",
+ "struct_field": {"two": u"111"},
+ "struct_field_2": {"bar": 27},
+ },
+ ]
+ )
+ current_schema = [
+ schema.SchemaField("status", field_type="STRING", mode="NULLABLE"),
+ schema.SchemaField("struct_field", field_type=None, mode="NULLABLE"),
+ schema.SchemaField("struct_field_2", field_type=None, mode="NULLABLE"),
+ ]
+
+ with warnings.catch_warnings(record=True) as warned:
+ currated_schema = module_under_test.currate_schema(dataframe, current_schema)
+
+ assert currated_schema is None
+
+ expected_warnings = [
+ warning for warning in warned if "could not determine" in str(warning)
+ ]
assert len(expected_warnings) == 1
warning_msg = str(expected_warnings[0])
- assert "all_items" in warning_msg
- assert "could not determine the type" in warning_msg
+ assert "pyarrow" in warning_msg.lower()
+ assert "struct_field" in warning_msg and "struct_field_2" in warning_msg
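For context, a minimal sketch (not part of the patch itself) of the fallback technique the new currate_schema() helper relies on: converting a single problematic column to a pyarrow array and looking up the array's type ID. The column name and the three-entry mapping below are made up for illustration; the mapping is a small excerpt of the ARROW_SCALAR_IDS_TO_BQ dict added above, and the snippet assumes pandas and pyarrow are installed.

    import datetime

    import pandas
    import pyarrow

    # Excerpt of the ARROW_SCALAR_IDS_TO_BQ mapping added in this patch.
    ARROW_SCALAR_IDS_TO_BQ = {
        pyarrow.date32().id: "DATE",
        pyarrow.int64().id: "INT64",
        pyarrow.string().id: "STRING",
    }

    # datetime.date values land in an "object" column, so the dtype-based
    # lookup in _PANDAS_DTYPE_TO_BQ cannot tell what the column holds.
    dataframe = pandas.DataFrame(
        {"created_at": [datetime.date(2019, 5, 10), datetime.date(2018, 9, 12)]}
    )

    # Converting just this column to a pyarrow array recovers the type; the
    # array's type ID then maps to the BigQuery type name.
    arrow_array = pyarrow.array(dataframe["created_at"])
    print(ARROW_SCALAR_IDS_TO_BQ.get(arrow_array.type.id))  # prints: DATE

Only the columns whose types the pandas dtype lookup could not resolve are converted this way, which is what keeps the fallback cheap compared to converting the whole DataFrame.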