diff --git a/redash/query_runner/big_query.py b/redash/query_runner/big_query.py index 455b23d585..49895b0b10 100644 --- a/redash/query_runner/big_query.py +++ b/redash/query_runner/big_query.py @@ -2,6 +2,7 @@ import logging import sys import time +import operator from base64 import b64decode import httplib2 @@ -121,6 +122,10 @@ class BigQuery(BaseQueryRunner): "default": "_v", "info": "This string will be used to toggle visibility of tables in the schema browser when editing a query in order to remove non-useful tables from sight." }, + 'samples': { + 'type': 'boolean', + 'title': 'Show Data Samples' + }, } @classmethod @@ -246,22 +251,118 @@ def _get_query_result(self, jobs, query): def _get_columns_schema(self, table_data): columns = [] + metadata = [] for column in table_data.get('schema', {}).get('fields', []): - columns.extend(self._get_columns_schema_column(column)) + metadatum = self._get_column_metadata(column) + metadata.extend(metadatum) + columns.extend(map(operator.itemgetter('name'), metadatum)) project_id = self._get_project_id() table_name = table_data['id'].replace("%s:" % project_id, "") - return {'name': table_name, 'columns': columns} + return {'name': table_name, 'columns': columns, 'metadata': metadata} + + def _get_column_metadata(self, column): + metadata = [] - def _get_columns_schema_column(self, column): - columns = [] if column['type'] == 'RECORD': for field in column['fields']: - columns.append(u"{}.{}".format(column['name'], field['name'])) + field_name = u"{}.{}".format(column['name'], field['name']) + metadata.append({'name': field_name, 'type': field['type']}) else: - columns.append(column['name']) + metadata.append({'name': column['name'], 'type': column['type']}) + + return metadata + + def _columns_and_samples_to_dict(self, schema, samples): + samples_dict = {} + if not samples: + return samples_dict + + # If a sample exists, its shape/length should be analogous to + # the schema provided (i.e their lengths should match up) + for i, column in enumerate(schema): + if column['type'] == 'RECORD': + if column.get('mode', None) == 'REPEATED': + # Repeated fields have multiple samples of the same format. + # We only need to show the first one as an example. + associated_sample = [] if len(samples[i]) == 0 else samples[i][0] + else: + associated_sample = samples[i] or [] + + for j, field in enumerate(column['fields']): + field_name = u"{}.{}".format(column['name'], field['name']) + samples_dict[field_name] = None + if len(associated_sample) > 0: + samples_dict[field_name] = associated_sample[j] + else: + samples_dict[column['name']] = samples[i] + + return samples_dict + + + def _flatten_samples(self, samples): + samples_list = [] + for field in samples: + value = field['v'] + if isinstance(value, dict): + samples_list.append( + self._flatten_samples(value.get('f', [])) + ) + elif isinstance(value, list): + samples_list.append( + self._flatten_samples(value) + ) + else: + samples_list.append(value) + + return samples_list + + def get_table_sample(self, table_name): + if not self.configuration.get('loadSchema', False): + return {} + + service = self._get_bigquery_service() + project_id = self._get_project_id() - return columns + dataset_id, table_id = table_name.split('.', 1) + + try: + # NOTE: the `sample_response` is limited by `maxResults` here. + # Without this limit, the response would be very large and require + # pagination using `nextPageToken`. + sample_response = service.tabledata().list( + projectId=project_id, + datasetId=dataset_id, + tableId=table_id, + fields="rows", + maxResults=1 + ).execute() + schema_response = service.tables().get( + projectId=project_id, + datasetId=dataset_id, + tableId=table_id, + fields="schema,id", + ).execute() + table_rows = sample_response.get('rows', []) + + if len(table_rows) == 0: + samples = [] + else: + samples = table_rows[0].get('f', []) + + schema = schema_response.get('schema', {}).get('fields', []) + columns = self._get_columns_schema(schema_response).get('columns', []) + + flattened_samples = self._flatten_samples(samples) + samples_dict = self._columns_and_samples_to_dict(schema, flattened_samples) + return samples_dict + except HttpError as http_error: + logger.exception( + "Error communicating with server for sample for table %s: %s", + table_name, + http_error + ) + return {} def get_schema(self, get_stats=False): if not self.configuration.get('loadSchema', False): diff --git a/tests/query_runner/test_bigquery.py b/tests/query_runner/test_bigquery.py new file mode 100644 index 0000000000..cbfe9322f7 --- /dev/null +++ b/tests/query_runner/test_bigquery.py @@ -0,0 +1,76 @@ +from mock import patch +from tests import BaseTestCase + +from redash.query_runner.big_query import BigQuery + + +class TestBigQuery(BaseTestCase): + + def test_get_table_sample_returns_expected_result(self): + SAMPLES_RESPONSE = { + 'rows': [ + {'f': [ + { + 'v': '2017-10-28' + }, { + 'v': '2019-03-28T18:57:04.485091' + }, { + 'v': '3341' + }, { + 'v': '2451' + }, { + 'v': 'Iran' + } + ]} + ] + } + + SCHEMA_RESPONSE = { + 'id': 'project:dataset.table', + 'schema': { + 'fields': [{ + 'type': 'DATE', + 'name': 'submission_date', + 'mode': 'NULLABLE' + }, { + 'type': 'DATETIME', + 'name': 'generated_time', + 'mode': 'NULLABLE' + }, { + 'type': 'INTEGER', + 'name': 'mau', + 'mode': 'NULLABLE' + }, { + 'type': 'INTEGER', + 'name': 'wau', + 'mode': 'NULLABLE' + }, { + 'type': 'STRING', + 'name': 'country', + 'mode': 'NULLABLE' + }] + } + } + + EXPECTED_SAMPLES_DICT = { + 'submission_date': '2017-10-28', + 'country': 'Iran', + 'wau': '2451', + 'mau': '3341', + 'generated_time': '2019-03-28T18:57:04.485091' + } + + with patch.object(BigQuery, '_get_bigquery_service') as get_bq_service: + tabledata_list = get_bq_service.return_value.tabledata.return_value.list + tabledata_list.return_value.execute.return_value = SAMPLES_RESPONSE + + tables_get = get_bq_service.return_value.tables.return_value.get + tables_get.return_value.execute.return_value = SCHEMA_RESPONSE + + query_runner = BigQuery({ + 'loadSchema': True, + 'projectId': 'test_project' + }) + table_sample = query_runner.get_table_sample("dataset.table") + + self.assertEqual(table_sample, EXPECTED_SAMPLES_DICT)