From aa38e4206e115e13c0a6c16c98276821ad443cea Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 29 May 2019 17:38:10 -0700 Subject: [PATCH] Support array and struct data type conversions. --- .../google/cloud/bigquery/_pandas_helpers.py | 60 ++++-- bigquery/tests/system.py | 11 +- bigquery/tests/unit/test__pandas_helpers.py | 191 ++++++++++++++++-- 3 files changed, 223 insertions(+), 39 deletions(-) diff --git a/bigquery/google/cloud/bigquery/_pandas_helpers.py b/bigquery/google/cloud/bigquery/_pandas_helpers.py index d5c76935c7e6..f1353cabb7f1 100644 --- a/bigquery/google/cloud/bigquery/_pandas_helpers.py +++ b/bigquery/google/cloud/bigquery/_pandas_helpers.py @@ -20,6 +20,11 @@ except ImportError: # pragma: NO COVER pyarrow = None +from google.cloud.bigquery import schema + + +STRUCT_TYPES = ("RECORD", "STRUCT") + def pyarrow_datetime(): return pyarrow.timestamp("us", tz=None) @@ -37,8 +42,7 @@ def pyarrow_timestamp(): return pyarrow.timestamp("us", tz="UTC") -BQ_TO_ARROW_SCALARS = {} -if pyarrow is not None: # pragma: NO COVER +if pyarrow: # pragma: NO COVER BQ_TO_ARROW_SCALARS = { "BOOL": pyarrow.bool_, "BOOLEAN": pyarrow.bool_, @@ -55,6 +59,8 @@ def pyarrow_timestamp(): "TIME": pyarrow_time, "TIMESTAMP": pyarrow_timestamp, } +else: + BQ_TO_ARROW_SCALARS = {} def bq_to_arrow_data_type(field): @@ -62,12 +68,17 @@ def bq_to_arrow_data_type(field): Returns None if default Arrow type inspection should be used. """ - # TODO: Use pyarrow.list_(item_type) for repeated (array) fields. if field.mode is not None and field.mode.upper() == "REPEATED": + inner_type = bq_to_arrow_data_type( + schema.SchemaField(field.name, field.field_type) + ) + if inner_type: + return pyarrow.list_(inner_type) return None - # TODO: Use pyarrow.struct(fields) for record (struct) fields. - if field.field_type.upper() in ("RECORD", "STRUCT"): - return None + + if field.field_type.upper() in STRUCT_TYPES: + arrow_fields = [bq_to_arrow_field(subfield) for subfield in field.fields] + return pyarrow.struct(arrow_fields) data_type_constructor = BQ_TO_ARROW_SCALARS.get(field.field_type.upper()) if data_type_constructor is None: @@ -75,6 +86,27 @@ def bq_to_arrow_data_type(field): return data_type_constructor() +def bq_to_arrow_field(bq_field): + """Return the Arrow field, corresponding to a given BigQuery column. + + Returns None if the Arrow type cannot be determined. + """ + arrow_type = bq_to_arrow_data_type(bq_field) + if arrow_type: + is_nullable = bq_field.mode.upper() == "NULLABLE" + return pyarrow.field(bq_field.name, arrow_type, nullable=is_nullable) + return None + + +def bq_to_arrow_array(series, bq_field): + arrow_type = bq_to_arrow_data_type(bq_field) + if bq_field.mode.upper() == "REPEATED": + return pyarrow.ListArray.from_pandas(series, type=arrow_type) + if bq_field.field_type.upper() in STRUCT_TYPES: + return pyarrow.StructArray.from_pandas(series, type=arrow_type) + return pyarrow.array(series, type=arrow_type) + + def to_parquet(dataframe, bq_schema, filepath): """Write dataframe as a Parquet file, according to the desired BQ schema. @@ -91,22 +123,18 @@ def to_parquet(dataframe, bq_schema, filepath): Path to write Parquet file to. """ if pyarrow is None: - raise ValueError("pyarrow is required for BigQuery schema conversion") + raise ValueError("pyarrow is required for BigQuery schema conversion.") if len(bq_schema) != len(dataframe.columns): raise ValueError( - "Number of columns in schema must match number of columns in dataframe" + "Number of columns in schema must match number of columns in dataframe." ) arrow_arrays = [] - column_names = [] + arrow_names = [] for bq_field in bq_schema: - column_names.append(bq_field.name) - arrow_arrays.append( - pyarrow.array( - dataframe[bq_field.name], type=bq_to_arrow_data_type(bq_field) - ) - ) + arrow_names.append(bq_field.name) + arrow_arrays.append(bq_to_arrow_array(dataframe[bq_field.name], bq_field)) - arrow_table = pyarrow.Table.from_arrays(arrow_arrays, names=column_names) + arrow_table = pyarrow.Table.from_arrays(arrow_arrays, names=arrow_names) pyarrow.parquet.write_table(arrow_table, filepath) diff --git a/bigquery/tests/system.py b/bigquery/tests/system.py index 8960fe93f4cd..7e7b356f4f8a 100644 --- a/bigquery/tests/system.py +++ b/bigquery/tests/system.py @@ -635,7 +635,7 @@ def test_load_table_from_dataframe_w_nulls(self): See: https://github.com/googleapis/google-cloud-python/issues/7370 """ # Schema with all scalar types. - table_schema = ( + scalars_schema = ( bigquery.SchemaField("bool_col", "BOOLEAN"), bigquery.SchemaField("bytes_col", "BYTES"), bigquery.SchemaField("date_col", "DATE"), @@ -648,6 +648,15 @@ def test_load_table_from_dataframe_w_nulls(self): bigquery.SchemaField("time_col", "TIME"), bigquery.SchemaField("ts_col", "TIMESTAMP"), ) + table_schema = scalars_schema + ( + # TODO: Array columns can't be read due to NULLABLE versus REPEATED + # mode mismatch. See: + # https://issuetracker.google.com/133415569#comment3 + # bigquery.SchemaField("array_col", "INTEGER", mode="REPEATED"), + # TODO: Support writing StructArrays to Parquet. See: + # https://jira.apache.org/jira/browse/ARROW-2587 + # bigquery.SchemaField("struct_col", "RECORD", fields=scalars_schema), + ) num_rows = 100 nulls = [None] * num_rows dataframe = pandas.DataFrame( diff --git a/bigquery/tests/unit/test__pandas_helpers.py b/bigquery/tests/unit/test__pandas_helpers.py index 2b0e3a8b0dbb..a1a4d87ec877 100644 --- a/bigquery/tests/unit/test__pandas_helpers.py +++ b/bigquery/tests/unit/test__pandas_helpers.py @@ -114,34 +114,181 @@ def test_all_(): ("DATETIME", "NULLABLE", is_datetime), ("GEOGRAPHY", "NULLABLE", pyarrow.types.is_string), ("UNKNOWN_TYPE", "NULLABLE", is_none), - # TODO: Use pyarrow.struct(fields) for record (struct) fields. - ("RECORD", "NULLABLE", is_none), - ("STRUCT", "NULLABLE", is_none), - # TODO: Use pyarrow.list_(item_type) for repeated (array) fields. - ("STRING", "REPEATED", is_none), - ("STRING", "repeated", is_none), - ("STRING", "RePeAtEd", is_none), - ("BYTES", "REPEATED", is_none), - ("INTEGER", "REPEATED", is_none), - ("INT64", "REPEATED", is_none), - ("FLOAT", "REPEATED", is_none), - ("FLOAT64", "REPEATED", is_none), - ("NUMERIC", "REPEATED", is_none), - ("BOOLEAN", "REPEATED", is_none), - ("BOOL", "REPEATED", is_none), - ("TIMESTAMP", "REPEATED", is_none), - ("DATE", "REPEATED", is_none), - ("TIME", "REPEATED", is_none), - ("DATETIME", "REPEATED", is_none), - ("GEOGRAPHY", "REPEATED", is_none), + # Use pyarrow.list_(item_type) for repeated (array) fields. + ( + "STRING", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_string(type_.value_type), + ), + ), + ( + "STRING", + "repeated", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_string(type_.value_type), + ), + ), + ( + "STRING", + "RePeAtEd", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_string(type_.value_type), + ), + ), + ( + "BYTES", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_binary(type_.value_type), + ), + ), + ( + "INTEGER", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_int64(type_.value_type), + ), + ), + ( + "INT64", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_int64(type_.value_type), + ), + ), + ( + "FLOAT", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_float64(type_.value_type), + ), + ), + ( + "FLOAT64", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_float64(type_.value_type), + ), + ), + ( + "NUMERIC", + "REPEATED", + all_(pyarrow.types.is_list, lambda type_: is_numeric(type_.value_type)), + ), + ( + "BOOLEAN", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_boolean(type_.value_type), + ), + ), + ( + "BOOL", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_boolean(type_.value_type), + ), + ), + ( + "TIMESTAMP", + "REPEATED", + all_(pyarrow.types.is_list, lambda type_: is_timestamp(type_.value_type)), + ), + ( + "DATE", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_date32(type_.value_type), + ), + ), + ( + "TIME", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_time64(type_.value_type), + ), + ), + ( + "DATETIME", + "REPEATED", + all_(pyarrow.types.is_list, lambda type_: is_datetime(type_.value_type)), + ), + ( + "GEOGRAPHY", + "REPEATED", + all_( + pyarrow.types.is_list, + lambda type_: pyarrow.types.is_string(type_.value_type), + ), + ), ("RECORD", "REPEATED", is_none), + ("UNKNOWN_TYPE", "REPEATED", is_none), ], ) @pytest.mark.skipIf(pyarrow is None, "Requires `pyarrow`") def test_bq_to_arrow_data_type(module_under_test, bq_type, bq_mode, is_correct_type): field = schema.SchemaField("ignored_name", bq_type, mode=bq_mode) - got = module_under_test.bq_to_arrow_data_type(field) - assert is_correct_type(got) + actual = module_under_test.bq_to_arrow_data_type(field) + assert is_correct_type(actual) + + +@pytest.mark.parametrize( + "bq_type", [("RECORD",), ("record",), ("STRUCT",), ("struct",)] +) +@pytest.mark.skipIf(pyarrow is None, "Requires `pyarrow`") +def test_bq_to_arrow_data_type_w_struct(module_under_test, bq_type): + fields = ( + schema.SchemaField("field01", "STRING"), + schema.SchemaField("field02", "BYTES"), + schema.SchemaField("field03", "INTEGER"), + schema.SchemaField("field04", "INT64"), + schema.SchemaField("field05", "FLOAT"), + schema.SchemaField("field06", "FLOAT64"), + schema.SchemaField("field07", "NUMERIC"), + schema.SchemaField("field08", "BOOLEAN"), + schema.SchemaField("field09", "BOOL"), + schema.SchemaField("field10", "TIMESTAMP"), + schema.SchemaField("field11", "DATE"), + schema.SchemaField("field12", "TIME"), + schema.SchemaField("field13", "DATETIME"), + schema.SchemaField("field14", "GEOGRAPHY"), + ) + field = schema.SchemaField("ignored_name", "RECORD", mode="NULLABLE", fields=fields) + actual = module_under_test.bq_to_arrow_data_type(field) + expected = pyarrow.struct( + ( + pyarrow.field("field01", pyarrow.string()), + pyarrow.field("field02", pyarrow.binary()), + pyarrow.field("field03", pyarrow.int64()), + pyarrow.field("field04", pyarrow.int64()), + pyarrow.field("field05", pyarrow.float64()), + pyarrow.field("field06", pyarrow.float64()), + pyarrow.field("field07", module_under_test.pyarrow_numeric()), + pyarrow.field("field08", pyarrow.bool_()), + pyarrow.field("field09", pyarrow.bool_()), + pyarrow.field("field10", module_under_test.pyarrow_timestamp()), + pyarrow.field("field11", pyarrow.date32()), + pyarrow.field("field12", module_under_test.pyarrow_time()), + pyarrow.field("field13", module_under_test.pyarrow_datetime()), + pyarrow.field("field14", pyarrow.string()), + ) + ) + assert pyarrow.types.is_struct(actual) + assert actual.num_children == len(fields) + assert actual.equals(expected) @pytest.mark.skipIf(pandas is None, "Requires `pandas`")