From 73b46b9629e85e96e7ba949f9ee2936a552cd9ad Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 28 Mar 2019 17:05:15 -0700 Subject: [PATCH 1/2] Make total_rows available on RowIterator before iteration After running a query, the total number of rows is available from the call to the getQueryResults API. This commit plumbs the total rows through to the faux Table created in QueryJob.results and then on through to the RowIterator created by list_rows. --- bigquery/docs/snippets.py | 3 +- bigquery/google/cloud/bigquery/job.py | 1 + bigquery/google/cloud/bigquery/table.py | 4 + bigquery/tests/unit/test_client.py | 27 ++++--- bigquery/tests/unit/test_job.py | 21 +++++- bigquery/tests/unit/test_table.py | 98 +++++++++++++++---------- 6 files changed, 101 insertions(+), 53 deletions(-) diff --git a/bigquery/docs/snippets.py b/bigquery/docs/snippets.py index 00569c40af18..5d0847e5b892 100644 --- a/bigquery/docs/snippets.py +++ b/bigquery/docs/snippets.py @@ -2246,8 +2246,7 @@ def test_client_query_total_rows(client, capsys): location="US", ) # API request - starts the query - results = query_job.result() # Waits for query to complete. - next(iter(results)) # Fetch the first page of results, which contains total_rows. + results = query_job.result() # Wait for query to complete. 
print("Got {} rows.".format(results.total_rows)) # [END bigquery_query_total_rows] diff --git a/bigquery/google/cloud/bigquery/job.py b/bigquery/google/cloud/bigquery/job.py index 94a2290cc29e..75cb57ad2894 100644 --- a/bigquery/google/cloud/bigquery/job.py +++ b/bigquery/google/cloud/bigquery/job.py @@ -2808,6 +2808,7 @@ def result(self, timeout=None, retry=DEFAULT_RETRY): schema = self._query_results.schema dest_table_ref = self.destination dest_table = Table(dest_table_ref, schema=schema) + dest_table._properties["numRows"] = self._query_results.total_rows return self._client.list_rows(dest_table, retry=retry) def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None): diff --git a/bigquery/google/cloud/bigquery/table.py b/bigquery/google/cloud/bigquery/table.py index ab22407eff1a..7fd23b4ac984 100644 --- a/bigquery/google/cloud/bigquery/table.py +++ b/bigquery/google/cloud/bigquery/table.py @@ -1300,7 +1300,11 @@ def __init__( ) self._schema = schema self._field_to_index = _helpers._field_to_index_mapping(schema) + self._total_rows = None + if table is not None and hasattr(table, "num_rows"): + self._total_rows = table.num_rows + self._page_size = page_size self._table = table self._selected_fields = selected_fields diff --git a/bigquery/tests/unit/test_client.py b/bigquery/tests/unit/test_client.py index 671bbdf29778..73125eaddd33 100644 --- a/bigquery/tests/unit/test_client.py +++ b/bigquery/tests/unit/test_client.py @@ -4115,18 +4115,21 @@ def test_list_rows_empty_table(self): client._connection = _make_connection(response, response) # Table that has no schema because it's an empty table. - rows = tuple( - client.list_rows( - # Test with using a string for the table ID. - "{}.{}.{}".format( - self.TABLE_REF.project, - self.TABLE_REF.dataset_id, - self.TABLE_REF.table_id, - ), - selected_fields=[], - ) + rows = client.list_rows( + # Test with using a string for the table ID. 
+ "{}.{}.{}".format( + self.TABLE_REF.project, + self.TABLE_REF.dataset_id, + self.TABLE_REF.table_id, + ), + selected_fields=[], ) - self.assertEqual(rows, ()) + + # When a table reference / string and selected_fields is provided, + # total_rows can't be populated until iteration starts. + self.assertIsNone(rows.total_rows) + self.assertEqual(tuple(rows), ()) + self.assertEqual(rows.total_rows, 0) def test_list_rows_query_params(self): from google.cloud.bigquery.table import Table, SchemaField @@ -4329,7 +4332,7 @@ def test_list_rows_with_missing_schema(self): conn.api_request.assert_called_once_with(method="GET", path=table_path) conn.api_request.reset_mock() - self.assertIsNone(row_iter.total_rows, msg=repr(table)) + self.assertEqual(row_iter.total_rows, 2, msg=repr(table)) rows = list(row_iter) conn.api_request.assert_called_once_with( diff --git a/bigquery/tests/unit/test_job.py b/bigquery/tests/unit/test_job.py index a42d9ffc311c..baf9ef67fe8b 100644 --- a/bigquery/tests/unit/test_job.py +++ b/bigquery/tests/unit/test_job.py @@ -4012,21 +4012,37 @@ def test_estimated_bytes_processed(self): self.assertEqual(job.estimated_bytes_processed, est_bytes) def test_result(self): + from google.cloud.bigquery.table import RowIterator + query_resource = { "jobComplete": True, "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, + "totalRows": "2", } - connection = _make_connection(query_resource, query_resource) + tabledata_resource = { + "totalRows": "1", + "pageToken": None, + "rows": [{"f": [{"v": "abc"}]}], + } + connection = _make_connection(query_resource, tabledata_resource) client = _make_client(self.PROJECT, connection=connection) resource = self._make_resource(ended=True) job = self._get_target_class().from_api_repr(resource, client) result = job.result() - self.assertEqual(list(result), []) + self.assertIsInstance(result, RowIterator) + self.assertEqual(result.total_rows, 2) + + rows = 
list(result) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0].col1, "abc") + self.assertEqual(result.total_rows, 1) def test_result_w_empty_schema(self): + from google.cloud.bigquery.table import _EmptyRowIterator + # Destination table may have no schema for some DDL and DML queries. query_resource = { "jobComplete": True, @@ -4040,6 +4056,7 @@ def test_result_w_empty_schema(self): result = job.result() + self.assertIsInstance(result, _EmptyRowIterator) self.assertEqual(list(result), []) def test_result_invokes_begins(self): diff --git a/bigquery/tests/unit/test_table.py b/bigquery/tests/unit/test_table.py index 4500856ec2a4..44eae538b1bd 100644 --- a/bigquery/tests/unit/test_table.py +++ b/bigquery/tests/unit/test_table.py @@ -1282,51 +1282,85 @@ def test_row(self): class Test_EmptyRowIterator(unittest.TestCase): - @mock.patch("google.cloud.bigquery.table.pandas", new=None) - def test_to_dataframe_error_if_pandas_is_none(self): + def _make_one(self): from google.cloud.bigquery.table import _EmptyRowIterator - row_iterator = _EmptyRowIterator() + return _EmptyRowIterator() + + def test_total_rows_eq_zero(self): + row_iterator = self._make_one() + self.assertEqual(row_iterator.total_rows, 0) + + @mock.patch("google.cloud.bigquery.table.pandas", new=None) + def test_to_dataframe_error_if_pandas_is_none(self): + row_iterator = self._make_one() with self.assertRaises(ValueError): row_iterator.to_dataframe() @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe(self): - from google.cloud.bigquery.table import _EmptyRowIterator - - row_iterator = _EmptyRowIterator() + row_iterator = self._make_one() df = row_iterator.to_dataframe() self.assertIsInstance(df, pandas.DataFrame) self.assertEqual(len(df), 0) # verify the number of rows class TestRowIterator(unittest.TestCase): - def test_constructor(self): + def _make_one( + self, client=None, api_request=None, path=None, schema=None, **kwargs + ): from google.cloud.bigquery.table import 
RowIterator + + if client is None: + client = _mock_client() + + if api_request is None: + api_request = mock.sentinel.api_request + + if path is None: + path = "/foo" + + if schema is None: + schema = [] + + return RowIterator(client, api_request, path, schema, **kwargs) + + def test_constructor(self): from google.cloud.bigquery.table import _item_to_row from google.cloud.bigquery.table import _rows_page_start client = _mock_client() - api_request = mock.sentinel.api_request - path = "/foo" - schema = [] - iterator = RowIterator(client, api_request, path, schema) + path = "/some/path" + iterator = self._make_one(client=client, path=path) - self.assertFalse(iterator._started) + # Objects are set without copying. self.assertIs(iterator.client, client) - self.assertEqual(iterator.path, path) self.assertIs(iterator.item_to_value, _item_to_row) + self.assertIs(iterator._page_start, _rows_page_start) + # Properties have the expect value. + self.assertEqual(iterator.extra_params, {}) self.assertEqual(iterator._items_key, "rows") self.assertIsNone(iterator.max_results) - self.assertEqual(iterator.extra_params, {}) - self.assertIs(iterator._page_start, _rows_page_start) + self.assertEqual(iterator.path, path) + self.assertFalse(iterator._started) + self.assertIsNone(iterator.total_rows) # Changing attributes. 
self.assertEqual(iterator.page_number, 0) self.assertIsNone(iterator.next_page_token) self.assertEqual(iterator.num_results, 0) + def test_constructor_with_table(self): + from google.cloud.bigquery.table import Table + + table = Table("proj.dset.tbl") + table._properties["numRows"] = 100 + + iterator = self._make_one(table=table) + + self.assertIs(iterator._table, table) + self.assertEqual(iterator.total_rows, 100) + def test_iterate(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1339,7 +1373,7 @@ def test_iterate(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) self.assertEqual(row_iterator.num_results, 0) rows_iter = iter(row_iterator) @@ -1358,7 +1392,6 @@ def test_iterate(self): api_request.assert_called_once_with(method="GET", path=path, query_params={}) def test_page_size(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1372,7 +1405,7 @@ def test_page_size(self): path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator( + row_iterator = self._make_one( _mock_client(), api_request, path, schema, page_size=4 ) row_iterator._get_next_page_response() @@ -1385,7 +1418,6 @@ def test_page_size(self): @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1400,7 +1432,7 @@ def test_to_dataframe(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe() @@ 
-1418,7 +1450,6 @@ def test_to_dataframe(self): def test_to_dataframe_progress_bar( self, tqdm_mock, tqdm_notebook_mock, tqdm_gui_mock ): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1441,7 +1472,7 @@ def test_to_dataframe_progress_bar( ) for progress_bar_type, progress_bar_mock in progress_bars: - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe(progress_bar_type=progress_bar_type) progress_bar_mock.assert_called() @@ -1451,7 +1482,6 @@ def test_to_dataframe_progress_bar( @unittest.skipIf(pandas is None, "Requires `pandas`") @mock.patch("google.cloud.bigquery.table.tqdm", new=None) def test_to_dataframe_no_tqdm_no_progress_bar(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1466,7 +1496,7 @@ def test_to_dataframe_no_tqdm_no_progress_bar(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) with warnings.catch_warnings(record=True) as warned: df = row_iterator.to_dataframe() @@ -1477,7 +1507,6 @@ def test_to_dataframe_no_tqdm_no_progress_bar(self): @unittest.skipIf(pandas is None, "Requires `pandas`") @mock.patch("google.cloud.bigquery.table.tqdm", new=None) def test_to_dataframe_no_tqdm(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1492,7 +1521,7 @@ def test_to_dataframe_no_tqdm(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) with warnings.catch_warnings(record=True) 
as warned: df = row_iterator.to_dataframe(progress_bar_type="tqdm") @@ -1511,7 +1540,6 @@ def test_to_dataframe_no_tqdm(self): @mock.patch("tqdm.tqdm_notebook", new=None) # will raise TypeError on call @mock.patch("tqdm.tqdm", new=None) # will raise TypeError on call def test_to_dataframe_tqdm_error(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1528,14 +1556,13 @@ def test_to_dataframe_tqdm_error(self): for progress_bar_type in ("tqdm", "tqdm_notebook", "tqdm_gui"): api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe(progress_bar_type=progress_bar_type) self.assertEqual(len(df), 4) # all should be well @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe_w_empty_results(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1544,7 +1571,7 @@ def test_to_dataframe_w_empty_results(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": []}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe() @@ -1555,7 +1582,6 @@ def test_to_dataframe_w_empty_results(self): @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe_w_various_types_nullable(self): import datetime - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1575,7 +1601,7 @@ def test_to_dataframe_w_various_types_nullable(self): rows = [{"f": [{"v": field} for field in row]} for row in row_data] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + 
row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe() @@ -1596,7 +1622,6 @@ def test_to_dataframe_w_various_types_nullable(self): @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe_column_dtypes(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1616,7 +1641,7 @@ def test_to_dataframe_column_dtypes(self): rows = [{"f": [{"v": field} for field in row]} for row in row_data] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe(dtypes={"km": "float16"}) @@ -1635,7 +1660,6 @@ def test_to_dataframe_column_dtypes(self): @mock.patch("google.cloud.bigquery.table.pandas", new=None) def test_to_dataframe_error_if_pandas_is_none(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1648,7 +1672,7 @@ def test_to_dataframe_error_if_pandas_is_none(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) with self.assertRaises(ValueError): row_iterator.to_dataframe() From 99441d71ea1233b8bd8089c06d835298b9cff84d Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 1 Apr 2019 09:12:07 -0700 Subject: [PATCH 2/2] Simplify RowIterator constructor. Add test comments. Use getattr instead of protecting with hasattr in the RowIterator constructor. Add comments about intentionally conflicting values for total_rows. 
---
 bigquery/google/cloud/bigquery/table.py | 12 ++++--------
 bigquery/tests/unit/test_job.py         |  4 ++++
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/bigquery/google/cloud/bigquery/table.py b/bigquery/google/cloud/bigquery/table.py
index 7fd23b4ac984..98021db77345 100644
--- a/bigquery/google/cloud/bigquery/table.py
+++ b/bigquery/google/cloud/bigquery/table.py
@@ -1298,17 +1298,13 @@ def __init__(
             page_start=_rows_page_start,
             next_token="pageToken",
         )
-        self._schema = schema
         self._field_to_index = _helpers._field_to_index_mapping(schema)
-
-        self._total_rows = None
-        if table is not None and hasattr(table, "num_rows"):
-            self._total_rows = table.num_rows
-
         self._page_size = page_size
-        self._table = table
-        self._selected_fields = selected_fields
         self._project = client.project
+        self._schema = schema
+        self._selected_fields = selected_fields
+        self._table = table
+        self._total_rows = getattr(table, "num_rows", None)
 
     def _get_next_page_response(self):
         """Requests the next page from the path provided.
diff --git a/bigquery/tests/unit/test_job.py b/bigquery/tests/unit/test_job.py
index baf9ef67fe8b..01de148d5bf8 100644
--- a/bigquery/tests/unit/test_job.py
+++ b/bigquery/tests/unit/test_job.py
@@ -4021,6 +4021,8 @@ def test_result(self):
             "totalRows": "2",
         }
         tabledata_resource = {
+            # Explicitly set totalRows to be different from the query response,
+            # to test update during iteration.
             "totalRows": "1",
             "pageToken": None,
             "rows": [{"f": [{"v": "abc"}]}],
@@ -4038,6 +4040,8 @@ def test_result(self):
         rows = list(result)
         self.assertEqual(len(rows), 1)
         self.assertEqual(rows[0].col1, "abc")
+        # Test that the total_rows property has changed during iteration, based
+        # on the response from tabledata.list.
         self.assertEqual(result.total_rows, 1)
 
     def test_result_w_empty_schema(self):