BigQuery: Autofetch table schema on load if not provided #9108

Merged: 7 commits, Sep 4, 2019

Changes from all commits
2 changes: 1 addition & 1 deletion bigquery/google/cloud/bigquery/_pandas_helpers.py
@@ -210,7 +210,7 @@ def list_columns_and_indexes(dataframe):
"""Return all index and column names with dtypes.

Returns:
-            Sequence[Tuple[dtype, str]]:
+            Sequence[Tuple[str, dtype]]:
Returns a sorted list of indexes and column names with
corresponding dtypes. If an index is missing a name or has the
same name as a column, the index is omitted.
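
The fix above swaps the tuple order in the docstring to match what the helper actually yields: name first, dtype second. As a rough illustration of that shape, here is a pandas-only sketch standing in for the private helper (the DataFrame and its columns are invented for the example):

import pandas

# Illustrative stand-in for list_columns_and_indexes(): pairs are
# (name, dtype), not (dtype, name).
df = pandas.DataFrame({"id": [1, 2], "age": [100, 60]})
pairs = [(name, df[name].dtype) for name in df.columns]
print(pairs)  # [('id', dtype('int64')), ('age', dtype('int64'))]
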
21 changes: 21 additions & 0 deletions bigquery/google/cloud/bigquery/client.py
@@ -1547,6 +1547,27 @@ def load_table_from_dataframe(
if location is None:
location = self.location

        # If a table schema is not provided, try to fetch the existing table's
        # schema and check whether the dataframe schema is compatible with it.
        # WRITE_TRUNCATE jobs are the exception: the existing schema is
        # replaced, so it does not matter there.
if (
not job_config.schema
and job_config.write_disposition != job.WriteDisposition.WRITE_TRUNCATE
):
try:
table = self.get_table(destination)
except google.api_core.exceptions.NotFound:
table = None
else:
columns_and_indexes = frozenset(
name
for name, _ in _pandas_helpers.list_columns_and_indexes(dataframe)
)
# schema fields not present in the dataframe are not needed
job_config.schema = [
field for field in table.schema if field.name in columns_and_indexes
]

job_config.schema = _pandas_helpers.dataframe_to_bq_schema(
dataframe, job_config.schema
)
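
For a sense of how the new code path behaves from the caller's side, here is a hedged usage sketch. The project, dataset, and table names are placeholders, default credentials are assumed, and TableReference.from_string is assumed from the library's public API; the rest (load_table_from_dataframe, LoadJobConfig, WriteDisposition) appears in this diff:

import pandas
from google.cloud import bigquery

client = bigquery.Client()  # assumes default credentials and project
table_ref = bigquery.TableReference.from_string("my-project.my_dataset.my_table")
dataframe = pandas.DataFrame([{"id": 1, "age": 100}, {"id": 2, "age": 60}])

# No schema on the job config: the client now fetches the destination
# table's schema via get_table() and keeps only the fields that also
# appear in the dataframe.
client.load_table_from_dataframe(dataframe, table_ref).result()

# WRITE_TRUNCATE replaces the table, so the existing schema is ignored
# and get_table() is never called.
job_config = bigquery.LoadJobConfig(
    write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE
)
client.load_table_from_dataframe(
    dataframe, table_ref, job_config=job_config
).result()
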
147 changes: 141 additions & 6 deletions bigquery/tests/unit/test_client.py
@@ -20,6 +20,7 @@
import gzip
import io
import json
import operator
import unittest
import warnings

@@ -5279,15 +5280,23 @@ def test_load_table_from_file_bad_mode(self):
def test_load_table_from_dataframe(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client()
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(dataframe, self.TABLE_REF)

load_table_from_file.assert_called_once_with(
@@ -5314,15 +5323,23 @@ def test_load_table_from_dataframe(self):
def test_load_table_from_dataframe_w_client_location(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client(location=self.LOCATION)
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(dataframe, self.TABLE_REF)

load_table_from_file.assert_called_once_with(
@@ -5349,20 +5366,33 @@ def test_load_table_from_dataframe_w_client_location(self):
def test_load_table_from_dataframe_w_custom_job_config(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client()
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)
-        job_config = job.LoadJobConfig()
+        job_config = job.LoadJobConfig(
+            write_disposition=job.WriteDisposition.WRITE_TRUNCATE
+        )

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch as get_table:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION
)

# no need to fetch and inspect table schema for WRITE_TRUNCATE jobs
assert not get_table.called

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
@@ -5378,6 +5408,7 @@ def test_load_table_from_dataframe_w_custom_job_config(self):

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config.source_format == job.SourceFormat.PARQUET
assert sent_config.write_disposition == job.WriteDisposition.WRITE_TRUNCATE

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
@@ -5421,7 +5452,12 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)

-        with load_patch as load_table_from_file:
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            side_effect=google.api_core.exceptions.NotFound("Table not found"),
+        )
+        with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, location=self.LOCATION
)
@@ -5449,6 +5485,100 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
SchemaField("ts_col", "TIMESTAMP"),
)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_index_and_auto_schema(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client()
df_data = collections.OrderedDict(
[("int_col", [10, 20, 30]), ("float_col", [1.0, 2.0, 3.0])]
)
dataframe = pandas.DataFrame(
df_data,
index=pandas.Index(name="unique_name", data=["one", "two", "three"]),
)

load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[
SchemaField("int_col", "INTEGER"),
SchemaField("float_col", "FLOAT"),
SchemaField("unique_name", "STRING"),
]
),
)
with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, location=self.LOCATION
)

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
self.TABLE_REF,
num_retries=_DEFAULT_NUM_RETRIES,
rewind=True,
job_id=mock.ANY,
job_id_prefix=None,
location=self.LOCATION,
project=None,
job_config=mock.ANY,
)

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config.source_format == job.SourceFormat.PARQUET

sent_schema = sorted(sent_config.schema, key=operator.attrgetter("name"))
expected_sent_schema = [
SchemaField("float_col", "FLOAT"),
SchemaField("int_col", "INTEGER"),
SchemaField("unique_name", "STRING"),
]
assert sent_schema == expected_sent_schema
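
The assertions above hinge on the named index being treated like a regular column. A pandas-only sketch of that effect, using reset_index() as a stand-in for the serialization step (names mirror the test):

import pandas

dataframe = pandas.DataFrame(
    {"int_col": [10, 20, 30], "float_col": [1.0, 2.0, 3.0]},
    index=pandas.Index(["one", "two", "three"], name="unique_name"),
)
# The named index surfaces as a third field alongside the two columns.
print(dataframe.reset_index().columns.tolist())
# ['unique_name', 'int_col', 'float_col']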

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_unknown_table(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES

client = self._make_client()
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
side_effect=google.api_core.exceptions.NotFound("Table not found"),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
with load_patch as load_table_from_file, get_table_patch:
# there should be no error
client.load_table_from_dataframe(dataframe, self.TABLE_REF)

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
self.TABLE_REF,
num_retries=_DEFAULT_NUM_RETRIES,
rewind=True,
job_id=mock.ANY,
job_id_prefix=None,
location=None,
project=None,
job_config=mock.ANY,
)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_struct_fields_error(self):
@@ -5741,6 +5871,11 @@ def test_load_table_from_dataframe_wo_pyarrow_custom_compression(self):
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
side_effect=google.api_core.exceptions.NotFound("Table not found"),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
@@ -5749,7 +5884,7 @@ def test_load_table_from_dataframe_wo_pyarrow_custom_compression(self):
dataframe, "to_parquet", wraps=dataframe.to_parquet
)

-        with load_patch, pyarrow_patch, to_parquet_patch as to_parquet_spy:
+        with load_patch, get_table_patch, pyarrow_patch, to_parquet_patch as to_parquet_spy:
client.load_table_from_dataframe(
dataframe,
self.TABLE_REF,