Skip to content

Commit

Permalink
Support array and struct data type conversions.
Browse files Browse the repository at this point in the history
  • Loading branch information
tswast committed May 30, 2019
1 parent 58e59ea commit aa38e42
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 39 deletions.
60 changes: 44 additions & 16 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
except ImportError: # pragma: NO COVER
pyarrow = None

from google.cloud.bigquery import schema


STRUCT_TYPES = ("RECORD", "STRUCT")


def pyarrow_datetime():
    """Arrow type used for a BigQuery DATETIME column (microseconds, no tz)."""
    return pyarrow.timestamp(unit="us", tz=None)
Expand All @@ -37,8 +42,7 @@ def pyarrow_timestamp():
return pyarrow.timestamp("us", tz="UTC")


BQ_TO_ARROW_SCALARS = {}
if pyarrow is not None: # pragma: NO COVER
if pyarrow: # pragma: NO COVER
BQ_TO_ARROW_SCALARS = {
"BOOL": pyarrow.bool_,
"BOOLEAN": pyarrow.bool_,
Expand All @@ -55,26 +59,54 @@ def pyarrow_timestamp():
"TIME": pyarrow_time,
"TIMESTAMP": pyarrow_timestamp,
}
else:
BQ_TO_ARROW_SCALARS = {}


def bq_to_arrow_data_type(field):
    """Return the Arrow data type, corresponding to a given BigQuery column.

    Returns None if default Arrow type inspection should be used.
    """
    # Repeated (array) columns map to an Arrow list of the element type.
    # Recurse with a non-repeated copy of the field to compute that element
    # type; guard on None because SchemaField.mode may be unset.
    if field.mode is not None and field.mode.upper() == "REPEATED":
        inner_type = bq_to_arrow_data_type(
            schema.SchemaField(field.name, field.field_type)
        )
        if inner_type:
            return pyarrow.list_(inner_type)
        return None

    # Record (struct) columns map to an Arrow struct of the sub-fields.
    if field.field_type.upper() in STRUCT_TYPES:
        arrow_fields = [bq_to_arrow_field(subfield) for subfield in field.fields]
        return pyarrow.struct(arrow_fields)

    # Scalar columns use the lookup table; unknown types fall back to
    # default Arrow type inspection (None).
    data_type_constructor = BQ_TO_ARROW_SCALARS.get(field.field_type.upper())
    if data_type_constructor is None:
        return None
    return data_type_constructor()


def bq_to_arrow_field(bq_field):
    """Return the Arrow field, corresponding to a given BigQuery column.

    Returns None if the Arrow type cannot be determined.
    """
    arrow_type = bq_to_arrow_data_type(bq_field)
    if arrow_type:
        # Guard against mode being None (BigQuery's default is NULLABLE);
        # an unguarded .upper() would raise AttributeError.
        is_nullable = bq_field.mode is None or bq_field.mode.upper() == "NULLABLE"
        return pyarrow.field(bq_field.name, arrow_type, nullable=is_nullable)
    return None


def bq_to_arrow_array(series, bq_field):
    """Convert a pandas Series to a pyarrow Array for the given BQ column.

    Repeated columns become ListArrays, struct columns become StructArrays,
    and all other columns use default pyarrow.array conversion.
    """
    arrow_type = bq_to_arrow_data_type(bq_field)
    # Guard against mode being None (treated as NULLABLE/scalar), matching
    # the None check in bq_to_arrow_data_type.
    if bq_field.mode is not None and bq_field.mode.upper() == "REPEATED":
        return pyarrow.ListArray.from_pandas(series, type=arrow_type)
    if bq_field.field_type.upper() in STRUCT_TYPES:
        return pyarrow.StructArray.from_pandas(series, type=arrow_type)
    return pyarrow.array(series, type=arrow_type)


def to_parquet(dataframe, bq_schema, filepath):
"""Write dataframe as a Parquet file, according to the desired BQ schema.
Expand All @@ -91,22 +123,18 @@ def to_parquet(dataframe, bq_schema, filepath):
Path to write Parquet file to.
"""
if pyarrow is None:
raise ValueError("pyarrow is required for BigQuery schema conversion")
raise ValueError("pyarrow is required for BigQuery schema conversion.")

if len(bq_schema) != len(dataframe.columns):
raise ValueError(
"Number of columns in schema must match number of columns in dataframe"
"Number of columns in schema must match number of columns in dataframe."
)

arrow_arrays = []
column_names = []
arrow_names = []
for bq_field in bq_schema:
column_names.append(bq_field.name)
arrow_arrays.append(
pyarrow.array(
dataframe[bq_field.name], type=bq_to_arrow_data_type(bq_field)
)
)
arrow_names.append(bq_field.name)
arrow_arrays.append(bq_to_arrow_array(dataframe[bq_field.name], bq_field))

arrow_table = pyarrow.Table.from_arrays(arrow_arrays, names=column_names)
arrow_table = pyarrow.Table.from_arrays(arrow_arrays, names=arrow_names)
pyarrow.parquet.write_table(arrow_table, filepath)
11 changes: 10 additions & 1 deletion bigquery/tests/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def test_load_table_from_dataframe_w_nulls(self):
See: https://github.com/googleapis/google-cloud-python/issues/7370
"""
# Schema with all scalar types.
table_schema = (
scalars_schema = (
bigquery.SchemaField("bool_col", "BOOLEAN"),
bigquery.SchemaField("bytes_col", "BYTES"),
bigquery.SchemaField("date_col", "DATE"),
Expand All @@ -648,6 +648,15 @@ def test_load_table_from_dataframe_w_nulls(self):
bigquery.SchemaField("time_col", "TIME"),
bigquery.SchemaField("ts_col", "TIMESTAMP"),
)
table_schema = scalars_schema + (
# TODO: Array columns can't be read due to NULLABLE versus REPEATED
# mode mismatch. See:
# https://issuetracker.google.com/133415569#comment3
# bigquery.SchemaField("array_col", "INTEGER", mode="REPEATED"),
# TODO: Support writing StructArrays to Parquet. See:
# https://jira.apache.org/jira/browse/ARROW-2587
# bigquery.SchemaField("struct_col", "RECORD", fields=scalars_schema),
)
num_rows = 100
nulls = [None] * num_rows
dataframe = pandas.DataFrame(
Expand Down
191 changes: 169 additions & 22 deletions bigquery/tests/unit/test__pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,34 +114,181 @@ def test_all_():
("DATETIME", "NULLABLE", is_datetime),
("GEOGRAPHY", "NULLABLE", pyarrow.types.is_string),
("UNKNOWN_TYPE", "NULLABLE", is_none),
# TODO: Use pyarrow.struct(fields) for record (struct) fields.
("RECORD", "NULLABLE", is_none),
("STRUCT", "NULLABLE", is_none),
# TODO: Use pyarrow.list_(item_type) for repeated (array) fields.
("STRING", "REPEATED", is_none),
("STRING", "repeated", is_none),
("STRING", "RePeAtEd", is_none),
("BYTES", "REPEATED", is_none),
("INTEGER", "REPEATED", is_none),
("INT64", "REPEATED", is_none),
("FLOAT", "REPEATED", is_none),
("FLOAT64", "REPEATED", is_none),
("NUMERIC", "REPEATED", is_none),
("BOOLEAN", "REPEATED", is_none),
("BOOL", "REPEATED", is_none),
("TIMESTAMP", "REPEATED", is_none),
("DATE", "REPEATED", is_none),
("TIME", "REPEATED", is_none),
("DATETIME", "REPEATED", is_none),
("GEOGRAPHY", "REPEATED", is_none),
# Use pyarrow.list_(item_type) for repeated (array) fields.
(
"STRING",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_string(type_.value_type),
),
),
(
"STRING",
"repeated",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_string(type_.value_type),
),
),
(
"STRING",
"RePeAtEd",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_string(type_.value_type),
),
),
(
"BYTES",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_binary(type_.value_type),
),
),
(
"INTEGER",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_int64(type_.value_type),
),
),
(
"INT64",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_int64(type_.value_type),
),
),
(
"FLOAT",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_float64(type_.value_type),
),
),
(
"FLOAT64",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_float64(type_.value_type),
),
),
(
"NUMERIC",
"REPEATED",
all_(pyarrow.types.is_list, lambda type_: is_numeric(type_.value_type)),
),
(
"BOOLEAN",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_boolean(type_.value_type),
),
),
(
"BOOL",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_boolean(type_.value_type),
),
),
(
"TIMESTAMP",
"REPEATED",
all_(pyarrow.types.is_list, lambda type_: is_timestamp(type_.value_type)),
),
(
"DATE",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_date32(type_.value_type),
),
),
(
"TIME",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_time64(type_.value_type),
),
),
(
"DATETIME",
"REPEATED",
all_(pyarrow.types.is_list, lambda type_: is_datetime(type_.value_type)),
),
(
"GEOGRAPHY",
"REPEATED",
all_(
pyarrow.types.is_list,
lambda type_: pyarrow.types.is_string(type_.value_type),
),
),
("RECORD", "REPEATED", is_none),
("UNKNOWN_TYPE", "REPEATED", is_none),
],
)
# pytest's skip marker is spelled lowercase "skipif"; "skipIf" (unittest
# spelling) is an unknown mark and never skips. The reason must be passed
# as a keyword, otherwise it is treated as an extra (truthy) condition.
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_bq_to_arrow_data_type(module_under_test, bq_type, bq_mode, is_correct_type):
    """Each parametrized (BQ type, mode) pair yields the expected Arrow type."""
    field = schema.SchemaField("ignored_name", bq_type, mode=bq_mode)
    actual = module_under_test.bq_to_arrow_data_type(field)
    assert is_correct_type(actual)


# Plain strings are the idiomatic argvalues for a single parametrized name.
@pytest.mark.parametrize("bq_type", ["RECORD", "record", "STRUCT", "struct"])
# Lowercase "skipif" with keyword reason; "skipIf" is not a pytest mark.
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_bq_to_arrow_data_type_w_struct(module_under_test, bq_type):
    """A struct column converts to a pyarrow.struct of its sub-field types."""
    fields = (
        schema.SchemaField("field01", "STRING"),
        schema.SchemaField("field02", "BYTES"),
        schema.SchemaField("field03", "INTEGER"),
        schema.SchemaField("field04", "INT64"),
        schema.SchemaField("field05", "FLOAT"),
        schema.SchemaField("field06", "FLOAT64"),
        schema.SchemaField("field07", "NUMERIC"),
        schema.SchemaField("field08", "BOOLEAN"),
        schema.SchemaField("field09", "BOOL"),
        schema.SchemaField("field10", "TIMESTAMP"),
        schema.SchemaField("field11", "DATE"),
        schema.SchemaField("field12", "TIME"),
        schema.SchemaField("field13", "DATETIME"),
        schema.SchemaField("field14", "GEOGRAPHY"),
    )
    # Use the parametrized bq_type (all four RECORD/STRUCT spellings),
    # not a hard-coded "RECORD" — otherwise the parametrization is a no-op.
    field = schema.SchemaField("ignored_name", bq_type, mode="NULLABLE", fields=fields)
    actual = module_under_test.bq_to_arrow_data_type(field)
    expected = pyarrow.struct(
        (
            pyarrow.field("field01", pyarrow.string()),
            pyarrow.field("field02", pyarrow.binary()),
            pyarrow.field("field03", pyarrow.int64()),
            pyarrow.field("field04", pyarrow.int64()),
            pyarrow.field("field05", pyarrow.float64()),
            pyarrow.field("field06", pyarrow.float64()),
            pyarrow.field("field07", module_under_test.pyarrow_numeric()),
            pyarrow.field("field08", pyarrow.bool_()),
            pyarrow.field("field09", pyarrow.bool_()),
            pyarrow.field("field10", module_under_test.pyarrow_timestamp()),
            pyarrow.field("field11", pyarrow.date32()),
            pyarrow.field("field12", module_under_test.pyarrow_time()),
            pyarrow.field("field13", module_under_test.pyarrow_datetime()),
            pyarrow.field("field14", pyarrow.string()),
        )
    )
    assert pyarrow.types.is_struct(actual)
    assert actual.num_children == len(fields)
    assert actual.equals(expected)


@pytest.mark.skipIf(pandas is None, "Requires `pandas`")
Expand Down

0 comments on commit aa38e42

Please sign in to comment.