Make total_rows available on RowIterator before iteration #7622

Merged 3 commits on Apr 16, 2019
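
For callers, the practical effect is that total_rows can be read as soon as the query job finishes, without forcing a page fetch first. A minimal sketch of the before/after behavior, assuming a configured BigQuery client; the query against a public dataset is purely illustrative:

from google.cloud import bigquery

client = bigquery.Client()
query_job = client.query(
    "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013` LIMIT 10",
    location="US",
)

results = query_job.result()  # Wait for the query to finish.

# Before this change, total_rows stayed None until a page had been fetched,
# so code had to force iteration first, e.g. next(iter(results)).
# With this change, total_rows is seeded from the query/table metadata.
print("Got {} rows.".format(results.total_rows))
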
3 changes: 1 addition & 2 deletions bigquery/docs/snippets.py
@@ -2246,8 +2246,7 @@ def test_client_query_total_rows(client, capsys):
         location="US",
     )  # API request - starts the query

-    results = query_job.result()  # Waits for query to complete.
-    next(iter(results))  # Fetch the first page of results, which contains total_rows.
+    results = query_job.result()  # Wait for query to complete.
     print("Got {} rows.".format(results.total_rows))
     # [END bigquery_query_total_rows]

1 change: 1 addition & 0 deletions bigquery/google/cloud/bigquery/job.py
@@ -2808,6 +2808,7 @@ def result(self, timeout=None, retry=DEFAULT_RETRY):
         schema = self._query_results.schema
         dest_table_ref = self.destination
         dest_table = Table(dest_table_ref, schema=schema)
+        dest_table._properties["numRows"] = self._query_results.total_rows
         return self._client.list_rows(dest_table, retry=retry)

     def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None):

8 changes: 4 additions & 4 deletions bigquery/google/cloud/bigquery/table.py
@@ -1267,13 +1267,13 @@ def __init__(
             page_start=_rows_page_start,
             next_token="pageToken",
         )
-        self._schema = schema
         self._field_to_index = _helpers._field_to_index_mapping(schema)
-        self._total_rows = None
         self._page_size = page_size
-        self._table = table
-        self._selected_fields = selected_fields
         self._project = client.project
+        self._schema = schema
+        self._selected_fields = selected_fields
+        self._table = table
+        self._total_rows = getattr(table, "num_rows", None)

     def _get_next_page_response(self):
         """Requests the next page from the path provided.
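
Read together with the job.py change above, these lines mean a RowIterator starts out knowing its row count whenever the object it was built from already carries one (a full Table exposing num_rows), and falls back to None for a TableReference or plain string. A self-contained sketch of just that seeding pattern, using stand-in classes rather than the real library:

class FakeTable:
    """Stand-in for google.cloud.bigquery.table.Table."""

    def __init__(self, num_rows=None):
        self.num_rows = num_rows


class FakeRowIterator:
    """Stand-in showing only the seeding behavior added in this PR."""

    def __init__(self, table):
        # Same pattern as the diff above: take the count from the table when
        # it has one; a TableReference or string table ID has no num_rows
        # attribute, so the count stays unknown until the first page arrives.
        self.total_rows = getattr(table, "num_rows", None)


assert FakeRowIterator(FakeTable(num_rows=42)).total_rows == 42
assert FakeRowIterator("my-project.my_dataset.my_table").total_rows is None
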
27 changes: 15 additions & 12 deletions bigquery/tests/unit/test_client.py
@@ -4359,18 +4359,21 @@ def test_list_rows_empty_table(self):
         client._connection = _make_connection(response, response)

         # Table that has no schema because it's an empty table.
-        rows = tuple(
-            client.list_rows(
-                # Test with using a string for the table ID.
-                "{}.{}.{}".format(
-                    self.TABLE_REF.project,
-                    self.TABLE_REF.dataset_id,
-                    self.TABLE_REF.table_id,
-                ),
-                selected_fields=[],
-            )
+        rows = client.list_rows(
+            # Test with using a string for the table ID.
+            "{}.{}.{}".format(
+                self.TABLE_REF.project,
+                self.TABLE_REF.dataset_id,
+                self.TABLE_REF.table_id,
+            ),
+            selected_fields=[],
         )
-        self.assertEqual(rows, ())
+
+        # When a table reference / string and selected_fields is provided,
+        # total_rows can't be populated until iteration starts.
+        self.assertIsNone(rows.total_rows)
+        self.assertEqual(tuple(rows), ())
+        self.assertEqual(rows.total_rows, 0)

     def test_list_rows_query_params(self):
         from google.cloud.bigquery.table import Table, SchemaField

@@ -4573,7 +4576,7 @@ def test_list_rows_with_missing_schema(self):

         conn.api_request.assert_called_once_with(method="GET", path=table_path)
         conn.api_request.reset_mock()
-        self.assertIsNone(row_iter.total_rows, msg=repr(table))
+        self.assertEqual(row_iter.total_rows, 2, msg=repr(table))

         rows = list(row_iter)
         conn.api_request.assert_called_once_with(

25 changes: 23 additions & 2 deletions bigquery/tests/unit/test_job.py
@@ -4012,21 +4012,41 @@ def test_estimated_bytes_processed(self):
         self.assertEqual(job.estimated_bytes_processed, est_bytes)

     def test_result(self):
+        from google.cloud.bigquery.table import RowIterator
+
         query_resource = {
             "jobComplete": True,
             "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
+            "schema": {"fields": [{"name": "col1", "type": "STRING"}]},
+            "totalRows": "2",
         }
-        connection = _make_connection(query_resource, query_resource)
+        tabledata_resource = {
+            # Explicitly set totalRows to be different from the query response.
+            # to test update during iteration.
+            "totalRows": "1",
+            "pageToken": None,
+            "rows": [{"f": [{"v": "abc"}]}],
+        }
+        connection = _make_connection(query_resource, tabledata_resource)
         client = _make_client(self.PROJECT, connection=connection)
         resource = self._make_resource(ended=True)
         job = self._get_target_class().from_api_repr(resource, client)

         result = job.result()

-        self.assertEqual(list(result), [])
+        self.assertIsInstance(result, RowIterator)
+        self.assertEqual(result.total_rows, 2)
+
+        rows = list(result)
+        self.assertEqual(len(rows), 1)
+        self.assertEqual(rows[0].col1, "abc")
+        # Test that the total_rows property has changed during iteration, based
+        # on the response from tabledata.list.
+        self.assertEqual(result.total_rows, 1)

Review thread on the final assertion, self.assertEqual(result.total_rows, 1):

Contributor:
Do you mean to test explicitly the case where the initial query result resource has a different number of rows than the count got by iteration? If so, then a comment above stating that would be helpful.

Contributor (author):
Yes, that's intentional. You're right that it looks wrong. I suspect it'll be quite rare in practice, but it can happen, especially when append query jobs are involved. Added comments in 99441d7.
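
In other words, the row count seeded from the query response and the count reported by tabledata.list during iteration can legitimately disagree, so total_rows may change once rows are actually fetched. A minimal sketch of what a caller could observe, assuming query_job is a completed QueryJob (the variable names here are illustrative, not part of the change):

result = query_job.result()
count_from_query = result.total_rows       # seeded from the query response

rows = list(result)                        # pages through tabledata.list
count_after_listing = result.total_rows    # may have been refreshed by the listing

# Usually the two counts match; the test above deliberately makes them differ
# (2 vs. 1) to exercise the refresh during iteration.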


     def test_result_w_empty_schema(self):
+        from google.cloud.bigquery.table import _EmptyRowIterator
+
         # Destination table may have no schema for some DDL and DML queries.
         query_resource = {
             "jobComplete": True,

@@ -4040,6 +4060,7 @@ def test_result_w_empty_schema(self):

         result = job.result()

+        self.assertIsInstance(result, _EmptyRowIterator)
         self.assertEqual(list(result), [])

     def test_result_invokes_begins(self):