Skip to content

Commit

Permalink
Make total_rows available on RowIterator before iteration (#7622)
Browse files Browse the repository at this point in the history
* Make total_rows available on RowIterator before iteration

After running a query, the total number of rows is available from the
call to the getQueryResults API. This commit plumbs the total rows
through to the faux Table created in QueryJob.results and then on
through to the RowIterator created by list_rows.

* Simplify RowIterator constructor. Add test comments.

Use getattr instead of protecting with hasattr in the RowIterator
constructor.

Add comments about intentionally conflicting values for total_rows.
  • Loading branch information
tswast authored Apr 16, 2019
1 parent d8212f2 commit 0bccf6c
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 57 deletions.
3 changes: 1 addition & 2 deletions bigquery/docs/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,8 +1908,7 @@ def test_client_query_total_rows(client, capsys):
location="US",
) # API request - starts the query

results = query_job.result() # Waits for query to complete.
next(iter(results)) # Fetch the first page of results, which contains total_rows.
results = query_job.result() # Wait for query to complete.
print("Got {} rows.".format(results.total_rows))
# [END bigquery_query_total_rows]

Expand Down
1 change: 1 addition & 0 deletions bigquery/google/cloud/bigquery/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2835,6 +2835,7 @@ def result(self, timeout=None, retry=DEFAULT_RETRY):
schema = self._query_results.schema
dest_table_ref = self.destination
dest_table = Table(dest_table_ref, schema=schema)
dest_table._properties["numRows"] = self._query_results.total_rows
return self._client.list_rows(dest_table, retry=retry)

def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None):
Expand Down
8 changes: 4 additions & 4 deletions bigquery/google/cloud/bigquery/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,13 +1308,13 @@ def __init__(
page_start=_rows_page_start,
next_token="pageToken",
)
self._schema = schema
self._field_to_index = _helpers._field_to_index_mapping(schema)
self._total_rows = None
self._page_size = page_size
self._table = table
self._selected_fields = selected_fields
self._project = client.project
self._schema = schema
self._selected_fields = selected_fields
self._table = table
self._total_rows = getattr(table, "num_rows", None)

def _get_next_page_response(self):
"""Requests the next page from the path provided.
Expand Down
27 changes: 15 additions & 12 deletions bigquery/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4359,18 +4359,21 @@ def test_list_rows_empty_table(self):
client._connection = _make_connection(response, response)

# Table that has no schema because it's an empty table.
rows = tuple(
client.list_rows(
# Test with using a string for the table ID.
"{}.{}.{}".format(
self.TABLE_REF.project,
self.TABLE_REF.dataset_id,
self.TABLE_REF.table_id,
),
selected_fields=[],
)
rows = client.list_rows(
# Test with using a string for the table ID.
"{}.{}.{}".format(
self.TABLE_REF.project,
self.TABLE_REF.dataset_id,
self.TABLE_REF.table_id,
),
selected_fields=[],
)
self.assertEqual(rows, ())

# When a table reference / string and selected_fields is provided,
# total_rows can't be populated until iteration starts.
self.assertIsNone(rows.total_rows)
self.assertEqual(tuple(rows), ())
self.assertEqual(rows.total_rows, 0)

def test_list_rows_query_params(self):
from google.cloud.bigquery.table import Table, SchemaField
Expand Down Expand Up @@ -4573,7 +4576,7 @@ def test_list_rows_with_missing_schema(self):

conn.api_request.assert_called_once_with(method="GET", path=table_path)
conn.api_request.reset_mock()
self.assertIsNone(row_iter.total_rows, msg=repr(table))
self.assertEqual(row_iter.total_rows, 2, msg=repr(table))

rows = list(row_iter)
conn.api_request.assert_called_once_with(
Expand Down
25 changes: 23 additions & 2 deletions bigquery/tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4024,21 +4024,41 @@ def test_estimated_bytes_processed(self):
self.assertEqual(job.estimated_bytes_processed, est_bytes)

def test_result(self):
from google.cloud.bigquery.table import RowIterator

query_resource = {
"jobComplete": True,
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "2",
}
connection = _make_connection(query_resource, query_resource)
tabledata_resource = {
        # Explicitly set totalRows to be different from the query response
        # to test update during iteration.
"totalRows": "1",
"pageToken": None,
"rows": [{"f": [{"v": "abc"}]}],
}
connection = _make_connection(query_resource, tabledata_resource)
client = _make_client(self.PROJECT, connection=connection)
resource = self._make_resource(ended=True)
job = self._get_target_class().from_api_repr(resource, client)

result = job.result()

self.assertEqual(list(result), [])
self.assertIsInstance(result, RowIterator)
self.assertEqual(result.total_rows, 2)

rows = list(result)
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0].col1, "abc")
# Test that the total_rows property has changed during iteration, based
# on the response from tabledata.list.
self.assertEqual(result.total_rows, 1)

def test_result_w_empty_schema(self):
from google.cloud.bigquery.table import _EmptyRowIterator

# Destination table may have no schema for some DDL and DML queries.
query_resource = {
"jobComplete": True,
Expand All @@ -4052,6 +4072,7 @@ def test_result_w_empty_schema(self):

result = job.result()

self.assertIsInstance(result, _EmptyRowIterator)
self.assertEqual(list(result), [])

def test_result_invokes_begins(self):
Expand Down
Loading

0 comments on commit 0bccf6c

Please sign in to comment.