diff --git a/cloudbuild_CI.yaml b/cloudbuild_CI.yaml
index 5ec4559f8..81f2cbb10 100644
--- a/cloudbuild_CI.yaml
+++ b/cloudbuild_CI.yaml
@@ -42,9 +42,7 @@ steps:
   - '--project ${PROJECT_ID}'
   - '--image_tag ${COMMIT_SHA}'
   - '--run_unit_tests'
-  - '--run_preprocessor_tests'
-  - '--run_bq_to_vcf_tests'
-  - '--run_all_tests'
+  - '--run_presubmit_tests'
   - '--test_name_prefix cloud-ci-'
  id: 'test-gcp-variant-transforms-docker'
  entrypoint: '/opt/gcp_variant_transforms/src/deploy_and_run_tests.sh'
diff --git a/gcp_variant_transforms/libs/bigquery_util.py b/gcp_variant_transforms/libs/bigquery_util.py
index c443d7baa..eb3256799 100644
--- a/gcp_variant_transforms/libs/bigquery_util.py
+++ b/gcp_variant_transforms/libs/bigquery_util.py
@@ -14,6 +14,7 @@
 """Constants and simple utility functions related to BigQuery."""
 
+from concurrent.futures import TimeoutError
 import enum
 import exceptions
 import logging
@@ -45,7 +46,6 @@
 _TOTAL_BASE_PAIRS_SIG_DIGITS = 4
 _PARTITION_SIZE_SIG_DIGITS = 1
 
-START_POSITION_COLUMN = 'start_position'
 _BQ_CREATE_PARTITIONED_TABLE_COMMAND = (
     'bq mk --table --range_partitioning='
     '{PARTITION_COLUMN},0,{RANGE_END},{RANGE_INTERVAL} '
@@ -54,10 +54,26 @@
 _BQ_CREATE_SAMPLE_INFO_TABLE_COMMAND = (
     'bq mk --table {FULL_TABLE_ID} {SCHEMA_FILE_PATH}')
 _BQ_DELETE_TABLE_COMMAND = 'bq rm -f -t {FULL_TABLE_ID}'
+_BQ_EXTRACT_SCHEMA_COMMAND = (
+    'bq show --schema --format=prettyjson {FULL_TABLE_ID} > {SCHEMA_FILE_PATH}')
 _GCS_DELETE_FILES_COMMAND = 'gsutil -m rm -f -R {ROOT_PATH}'
-_BQ_LOAD_JOB_NUM_RETRIES = 5
+_BQ_NUM_RETRIES = 3
 _MAX_NUM_CONCURRENT_BQ_LOAD_JOBS = 4
 
+_GET_COLUMN_NAMES_QUERY = (
+    'SELECT column_name '
+    'FROM `{PROJECT_ID}`.{DATASET_ID}.INFORMATION_SCHEMA.COLUMNS '
+    'WHERE table_name = "{TABLE_ID}"')
+_GET_CALL_SUB_FIELDS_QUERY = (
+    'SELECT field_path '
+    'FROM `{PROJECT_ID}`.{DATASET_ID}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS '
+    'WHERE table_name = "{TABLE_ID}" AND column_name="{CALL_COLUMN}"')
+_MAIN_TABLE_ALIAS = 'main_table'
+_CALL_TABLE_ALIAS = 'call_table'
+_FLATTEN_CALL_QUERY = (
+    'SELECT {SELECT_COLUMNS} '
+    'FROM `{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}` as {MAIN_TABLE_ALIAS}, '
+    'UNNEST({CALL_COLUMN}) as {CALL_TABLE_ALIAS}')
 
 
 class ColumnKeyConstants(object):
   """Constants for column names in the BigQuery schema."""
@@ -75,6 +91,9 @@ class ColumnKeyConstants(object):
   CALLS_GENOTYPE = 'genotype'
   CALLS_PHASESET = 'phaseset'
 
+CALL_SAMPLE_ID_COLUMN = (ColumnKeyConstants.CALLS + '_' +
+                         ColumnKeyConstants.CALLS_SAMPLE_ID)
+
 
 class TableFieldConstants(object):
   """Constants for field modes/types in the BigQuery schema."""
@@ -435,7 +454,7 @@ def _cancel_all_running_load_jobs(self):
       load_job.cancel()
 
   def _handle_failed_load_job(self, suffix, load_job):
-    if self._num_load_jobs_retries < _BQ_LOAD_JOB_NUM_RETRIES:
+    if self._num_load_jobs_retries < _BQ_NUM_RETRIES:
       self._num_load_jobs_retries += 1
       # Retry the failed job after 5 minutes wait.
       time.sleep(300)
@@ -482,6 +501,156 @@ def create_sample_info_table(output_table_id):
       SCHEMA_FILE_PATH=SAMPLE_INFO_TABLE_SCHEMA_FILE_PATH)
   _run_table_creation_command(bq_command)
 
+
+class FlattenCallColumn(object):
+  """Flattens the call column of variant tables for sample lookup tables."""
+
+  def __init__(self, base_table_id, suffixes):
+    (self._project_id,
+     self._dataset_id,
+     self._base_table) = parse_table_reference(base_table_id)
+    assert suffixes
+    self._suffixes = suffixes[:]
+
+    # Any of the input tables can serve as the schema source; we use index 0.
+    self._schema_table_id = compose_table_name(self._base_table,
+                                               suffixes[0])
+    self._column_names = []
+    self._sub_fields = []
+    self._client = bigquery.Client(project=self._project_id)
+
+  def _run_query(self, query):
+    query_job = self._client.query(query)
+    num_retries = 0
+    while True:
+      try:
+        iterator = query_job.result(timeout=300)
+      except TimeoutError as e:
+        logging.warning('Timed out waiting for query: %s', query)
+        if num_retries < _BQ_NUM_RETRIES:
+          num_retries += 1
+          time.sleep(90)
+        else:
+          raise e
+      else:
+        break
+    result = []
+    for i in iterator:
+      result.append(str(i.values()[0]))
+    return result
+
+  def _get_column_names(self):
+    if not self._column_names:
+      query = _GET_COLUMN_NAMES_QUERY.format(PROJECT_ID=self._project_id,
+                                             DATASET_ID=self._dataset_id,
+                                             TABLE_ID=self._schema_table_id)
+      self._column_names = self._run_query(query)[:]
+      assert self._column_names
+    return self._column_names
+
+  def _get_call_sub_fields(self):
+    if not self._sub_fields:
+      query = _GET_CALL_SUB_FIELDS_QUERY.format(
+          PROJECT_ID=self._project_id, DATASET_ID=self._dataset_id,
+          TABLE_ID=self._schema_table_id, CALL_COLUMN=ColumnKeyConstants.CALLS)
+      # Returned list: [call, call.sample_id, call.genotype, ...]
+      result = self._run_query(query)[1:]  # Drop the first element
+      self._sub_fields = [sub_field.split('.')[1] for sub_field in result]
+      assert self._sub_fields
+    return self._sub_fields
+
+  def _get_flatten_column_names(self):
+    column_names = self._get_column_names()
+    sub_fields = self._get_call_sub_fields()
+    select_list = []
+    for column in column_names:
+      if column != ColumnKeyConstants.CALLS:
+        select_list.append(_MAIN_TABLE_ALIAS + '.' + column + ' AS `'+
+                           column + '`')
+      else:
+        for s_f in sub_fields:
+          select_list.append(_CALL_TABLE_ALIAS + '.' + s_f + ' AS `' +
+                             ColumnKeyConstants.CALLS + '_' + s_f + '`')
+    return ', '.join(select_list)
+
+  def _copy_to_flatten_table(self, output_table_id, cp_query):
+    job_config = bigquery.QueryJobConfig(destination=output_table_id)
+    query_job = self._client.query(cp_query, job_config=job_config)
+    num_retries = 0
+    while True:
+      try:
+        _ = query_job.result(timeout=600)
+      except TimeoutError as e:
+        logging.warning('Timed out waiting for query: %s', cp_query)
+        if num_retries < _BQ_NUM_RETRIES:
+          num_retries += 1
+          time.sleep(90)
+        else:
+          logging.error('Copy to table query failed: %s', output_table_id)
+          raise e
+      else:
+        break
+    logging.info('Copy to table query was successful: %s', output_table_id)
+
+  def _create_temp_flatten_table(self):
+    temp_suffix = time.strftime('%Y%m%d_%H%M%S')
+    temp_table_id = '{}{}'.format(self._schema_table_id, temp_suffix)
+    full_output_table_id = '{}.{}.{}'.format(
+        self._project_id, self._dataset_id, temp_table_id)
+
+    select_columns = self._get_flatten_column_names()
+    cp_query = _FLATTEN_CALL_QUERY.format(SELECT_COLUMNS=select_columns,
+                                          PROJECT_ID=self._project_id,
+                                          DATASET_ID=self._dataset_id,
+                                          TABLE_ID=self._schema_table_id,
+                                          MAIN_TABLE_ALIAS=_MAIN_TABLE_ALIAS,
+                                          CALL_COLUMN=ColumnKeyConstants.CALLS,
+                                          CALL_TABLE_ALIAS=_CALL_TABLE_ALIAS)
+    cp_query += ' LIMIT 1'  # We need this table only to extract its schema.
+    self._copy_to_flatten_table(full_output_table_id, cp_query)
+    logging.info('A new table with 1 row was created: %s', full_output_table_id)
+    logging.info('This table is used to extract the schema of the flattened '
+                 'table.')
+    return temp_table_id
+
+  def get_flatten_table_schema(self, schema_file_path):
+    temp_table_id = self._create_temp_flatten_table()
+    full_table_id = '{}:{}.{}'.format(
+        self._project_id, self._dataset_id, temp_table_id)
+    bq_command = _BQ_EXTRACT_SCHEMA_COMMAND.format(
+        FULL_TABLE_ID=full_table_id,
+        SCHEMA_FILE_PATH=schema_file_path)
+    result = os.system(bq_command)
+    if result != 0:
+      logging.error('Failed to extract the flattened table schema using '
+                    '"%s" command', bq_command)
+    else:
+      logging.info('Successfully extracted the flattened table schema.')
+    if _delete_table(full_table_id) == 0:
+      logging.info('Successfully deleted temporary table: %s', full_table_id)
+    else:
+      logging.error('Was not able to delete temporary table: %s', full_table_id)
+    return result
+
+  def copy_to_flatten_table(self, output_base_table_id):
+    # Here we assume all output_base_table + suffixes[:] are already created.
+    (output_project_id,
+     output_dataset_id,
+     output_base_table) = parse_table_reference(output_base_table_id)
+    select_columns = self._get_flatten_column_names()
+    for suffix in self._suffixes:
+      input_table_id = compose_table_name(self._base_table, suffix)
+      output_table_id = compose_table_name(output_base_table, suffix)
+
+      full_output_table_id = '{}.{}.{}'.format(
+          output_project_id, output_dataset_id, output_table_id)
+      cp_query = _FLATTEN_CALL_QUERY.format(
+          SELECT_COLUMNS=select_columns, PROJECT_ID=self._project_id,
+          DATASET_ID=self._dataset_id, TABLE_ID=input_table_id,
+          MAIN_TABLE_ALIAS=_MAIN_TABLE_ALIAS,
+          CALL_COLUMN=ColumnKeyConstants.CALLS,
+          CALL_TABLE_ALIAS=_CALL_TABLE_ALIAS)
+
+      self._copy_to_flatten_table(full_output_table_id, cp_query)
+      logging.info('Flattened table is fully loaded: %s', full_output_table_id)
+
+
 def create_output_table(full_table_id,  # type: str
                         partition_column,  # type: str
                         range_end,  # type: int
diff --git a/gcp_variant_transforms/libs/bigquery_util_test.py b/gcp_variant_transforms/libs/bigquery_util_test.py
index 9f429474d..f9c98f136 100644
--- a/gcp_variant_transforms/libs/bigquery_util_test.py
+++ b/gcp_variant_transforms/libs/bigquery_util_test.py
@@ -417,7 +417,6 @@ def test_merge_field_schemas_merge_inner_record_fields(self):
             field_schemas_2),
         expected_merged_field_schemas)
 
-
   def test_does_table_exist(self):
     client = mock.Mock()
     client.tables.Get.return_value = bigquery.Table(
@@ -516,3 +515,48 @@ def test_calculate_optimal_range_interval_large(self):
         bigquery_util.calculate_optimal_range_interval(large_range_end))
     self.assertEqual(expected_interval, range_interval)
     self.assertEqual(expected_end, range_end_enlarged)
+
+
+class FlattenCallColumnTest(unittest.TestCase):
+  """Test cases for class `FlattenCallColumn`."""
+
+  def setUp(self):
+    input_base_table = ('gcp-variant-transforms-test:'
+                        'bq_to_vcf_integration_tests.'
+                        'merge_option_move_to_calls')
+    self._flatter = bigquery_util.FlattenCallColumn(input_base_table, ['chr20'])
+
+  def test_get_column_names(self):
+    expected_column_names = ['reference_name', 'start_position', 'end_position',
+                             'reference_bases', 'alternate_bases', 'names',
+                             'quality', 'filter', 'call', 'NS', 'DP', 'AA',
+                             'DB', 'H2']
+    self.assertEqual(expected_column_names, self._flatter._get_column_names())
+
+  def test_get_call_sub_fields(self):
+    expected_sub_fields = \
+        ['sample_id', 'genotype', 'phaseset', 'DP', 'GQ', 'HQ']
+    self.assertEqual(expected_sub_fields, self._flatter._get_call_sub_fields())
+
+  def test_get_flatten_column_names(self):
+    expected_select = (
+        'main_table.reference_name AS `reference_name`, '
+        'main_table.start_position AS `start_position`, '
+        'main_table.end_position AS `end_position`, '
+        'main_table.reference_bases AS `reference_bases`, '
+        'main_table.alternate_bases AS `alternate_bases`, '
+        'main_table.names AS `names`, '
+        'main_table.quality AS `quality`, '
+        'main_table.filter AS `filter`, '
+        'call_table.sample_id AS `call_sample_id`, '
+        'call_table.genotype AS `call_genotype`, '
+        'call_table.phaseset AS `call_phaseset`, '
+        'call_table.DP AS `call_DP`, '
+        'call_table.GQ AS `call_GQ`, '
+        'call_table.HQ AS `call_HQ`, '
+        'main_table.NS AS `NS`, '
+        'main_table.DP AS `DP`, '
+        'main_table.AA AS `AA`, '
+        'main_table.DB AS `DB`, '
+        'main_table.H2 AS `H2`')
+    self.assertEqual(expected_select, self._flatter._get_flatten_column_names())
diff --git a/gcp_variant_transforms/options/variant_transform_options.py b/gcp_variant_transforms/options/variant_transform_options.py
index e1a7c8222..e0c8c495d 100644
--- a/gcp_variant_transforms/options/variant_transform_options.py
+++ b/gcp_variant_transforms/options/variant_transform_options.py
@@ -124,9 +124,27 @@ class BigQueryWriteOptions(VariantTransformsOptions):
 
   def add_arguments(self, parser):
     # type: (argparse.ArgumentParser) -> None
-    parser.add_argument('--output_table',
-                        default='',
-                        help='BigQuery table to store the results.')
+    parser.add_argument(
+        '--output_table',
+        default='',
+        help=('Base name of the BigQuery tables which will store the results. '
+              'Note that sharded tables will be named as follows: '
+              ' * `output_table`__chr1 '
+              ' * `output_table`__chr2 '
+              ' * ... '
+              ' * `output_table`__residual '
+              'where "chr1", "chr2", ..., and "residual" suffixes correspond '
+              'to the value of `table_name_suffix` in the sharding config file '
+              '(see --sharding_config_path).'))
+
+    parser.add_argument(
+        '--sample_lookup_optimized_output_table',
+        default='',
+        help=('In addition to the default output tables (which are optimized '
+              'for variant lookup queries), you can store a second copy of '
+              'your data in BigQuery tables that are optimized for sample '
+              'lookup queries. Note that setting this option will double your '
+              'BigQuery storage costs.'))
     parser.add_argument(
         '--output_avro_path',
         default='',
@@ -206,54 +224,66 @@ def validate(self, parsed_args, client=None):
           not parsed_args.sharding_config_path.strip()):
         raise ValueError(
             '--sharding_config_path must point to a valid config file.')
-    # Ensuring (not) existence of output tables is aligned with --append value.
-    if parsed_args.output_table:
-      if (parsed_args.output_table !=
-          bigquery_util.get_table_base_name(parsed_args.output_table)):
-        raise ValueError(('Output table cannot contain "{}" we reserve this '
+
+    if not client:
+      credentials = GoogleCredentials.get_application_default().create_scoped(
+          ['https://www.googleapis.com/auth/bigquery'])
+      client = bigquery.BigqueryV2(credentials=credentials)
+    if not parsed_args.output_table:
+      raise ValueError('--output_table must have a value.')
+    self._validate_output_tables(
+        client, parsed_args.output_table,
+        parsed_args.sharding_config_path, parsed_args.append)
+
+    if parsed_args.sample_lookup_optimized_output_table:
+      self._validate_output_tables(
+          client, parsed_args.sample_lookup_optimized_output_table,
+          parsed_args.sharding_config_path, parsed_args.append)
+
+  def _validate_output_tables(self, client,
+                              output_table_base_name,
+                              sharding_config_path, append):
+    if (output_table_base_name !=
+        bigquery_util.get_table_base_name(output_table_base_name)):
+      raise ValueError(('Output table cannot contain "{}" we reserve this '
+                        'string to mark sharded output tables.').format(
+                            bigquery_util.TABLE_SUFFIX_SEPARATOR))
+
+    project_id, dataset_id, table_id = bigquery_util.parse_table_reference(
+        output_table_base_name)
+    bigquery_util.raise_error_if_dataset_not_exists(client, project_id,
+                                                    dataset_id)
+    all_output_tables = []
+    all_output_tables.append(
+        bigquery_util.compose_table_name(table_id, SAMPLE_INFO_TABLE_SUFFIX))
+    sharding = variant_sharding.VariantSharding(sharding_config_path)
+    num_shards = sharding.get_num_shards()
+    # In case there is no residual in config we will ignore the last shard.
+    if not sharding.should_keep_shard(sharding.get_residual_index()):
+      num_shards -= 1
+    for i in range(num_shards):
+      table_suffix = sharding.get_output_table_suffix(i)
+      if table_suffix != bigquery_util.get_table_base_name(table_suffix):
+        raise ValueError(('Table suffix cannot contain "{}" we reserve this '
                           'string to mark sharded output tables.').format(
                               bigquery_util.TABLE_SUFFIX_SEPARATOR))
-      if not client:
-        credentials = GoogleCredentials.get_application_default().create_scoped(
-            ['https://www.googleapis.com/auth/bigquery'])
-        client = bigquery.BigqueryV2(credentials=credentials)
-
-      project_id, dataset_id, table_id = bigquery_util.parse_table_reference(
-          parsed_args.output_table)
-      bigquery_util.raise_error_if_dataset_not_exists(client, project_id,
-                                                      dataset_id)
-      all_output_tables = []
-      all_output_tables.append(
-          bigquery_util.compose_table_name(table_id, SAMPLE_INFO_TABLE_SUFFIX))
-      sharding = variant_sharding.VariantSharding(
-          parsed_args.sharding_config_path)
-      num_shards = sharding.get_num_shards()
-      # In case there is no residual in config we will ignore the last shard.
-      if not sharding.should_keep_shard(sharding.get_residual_index()):
-        num_shards -= 1
-      for i in range(num_shards):
-        table_suffix = sharding.get_output_table_suffix(i)
-        if table_suffix != bigquery_util.get_table_base_name(table_suffix):
-          raise ValueError(('Table suffix cannot contain "{}" we reserve this '
-                            'string to mark sharded output tables.').format(
-                                bigquery_util.TABLE_SUFFIX_SEPARATOR))
-        all_output_tables.append(bigquery_util.compose_table_name(table_id,
-                                                                  table_suffix))
-
-      for output_table in all_output_tables:
-        if parsed_args.append:
-          if not bigquery_util.table_exist(client, project_id,
-                                           dataset_id, output_table):
-            raise ValueError(
-                'Table {}:{}.{} does not exist, cannot append to it.'.format(
-                    project_id, dataset_id, output_table))
-        else:
-          if bigquery_util.table_exist(client, project_id,
-                                       dataset_id, output_table):
-            raise ValueError(
-                ('Table {}:{}.{} already exists, cannot overwrite it. Please '
-                 'set `--append True` if you want to append to it.').format(
-                    project_id, dataset_id, output_table))
+      all_output_tables.append(bigquery_util.compose_table_name(table_id,
+                                                                table_suffix))
+
+    for output_table in all_output_tables:
+      if append:
+        if not bigquery_util.table_exist(client, project_id,
+                                         dataset_id, output_table):
+          raise ValueError(
+              'Table {}:{}.{} does not exist, cannot append to it.'.format(
+                  project_id, dataset_id, output_table))
+      else:
+        if bigquery_util.table_exist(client, project_id,
+                                     dataset_id, output_table):
+          raise ValueError(
+              ('Table {}:{}.{} already exists, cannot overwrite it. Please '
+               'set `--append True` if you want to append to it.').format(
+                  project_id, dataset_id, output_table))
 
 
 class AnnotationOptions(VariantTransformsOptions):
diff --git a/gcp_variant_transforms/testing/data/schema/flatten_merge_option_move_to_calls.json b/gcp_variant_transforms/testing/data/schema/flatten_merge_option_move_to_calls.json
new file mode 100644
index 000000000..15443f8f9
--- /dev/null
+++ b/gcp_variant_transforms/testing/data/schema/flatten_merge_option_move_to_calls.json
@@ -0,0 +1,109 @@
+[
+  {
+    "mode": "NULLABLE",
+    "name": "reference_name",
+    "type": "STRING"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "start_position",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "end_position",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "reference_bases",
+    "type": "STRING"
+  },
+  {
+    "fields": [
+      {
+        "mode": "NULLABLE",
+        "name": "alt",
+        "type": "STRING"
+      },
+      {
+        "mode": "NULLABLE",
+        "name": "AF",
+        "type": "FLOAT"
+      }
+    ],
+    "mode": "REPEATED",
+    "name": "alternate_bases",
+    "type": "RECORD"
+  },
+  {
+    "mode": "REPEATED",
+    "name": "names",
+    "type": "STRING"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "quality",
+    "type": "FLOAT"
+  },
+  {
+    "mode": "REPEATED",
+    "name": "filter",
+    "type": "STRING"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "call_sample_id",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "REPEATED",
+    "name": "call_genotype",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "call_phaseset",
+    "type": "STRING"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "call_DP",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "call_GQ",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "REPEATED",
+    "name": "call_HQ",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "NS",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "DP",
+    "type": "INTEGER"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "AA",
+    "type": "STRING"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "DB",
+    "type": "BOOLEAN"
+  },
+  {
"mode": "NULLABLE", + "name": "H2", + "type": "BOOLEAN" + } +] diff --git a/gcp_variant_transforms/vcf_to_bq.py b/gcp_variant_transforms/vcf_to_bq.py index c1a476e33..7662534b1 100644 --- a/gcp_variant_transforms/vcf_to_bq.py +++ b/gcp_variant_transforms/vcf_to_bq.py @@ -558,6 +558,33 @@ def run(argv=None): 'failed.', avro_root_path) + if known_args.sample_lookup_optimized_output_table: + flatten_call_column = bigquery_util.FlattenCallColumn( + known_args.output_table, suffixes) + try: + flatten_schema_file = tempfile.mkstemp(suffix=_BQ_SCHEMA_FILE_SUFFIX)[1] + if flatten_call_column.get_flatten_table_schema(flatten_schema_file) != 0: + raise ValueError('Failed to extract schema of flatten table') + # Create output flatten tables if needed + if not known_args.append: + for suffix in suffixes: + output_table_id = bigquery_util.compose_table_name( + known_args.sample_lookup_optimized_output_table, suffix) + bigquery_util.create_output_table(output_table_id, + bigquery_util.CALL_SAMPLE_ID_COLUMN, + bigquery_util.MAX_RANGE_END, + flatten_schema_file) + logging.info('Sample lookup optimized table %s was created.', + output_table_id) + # Copy to flatten sample lookup tables from the variant lookup tables. + flatten_call_column.copy_to_flatten_table( + known_args.sample_lookup_optimized_output_table) + logging.info('All sample lookup optimized tables are fully loaded.') + except Exception as e: + logging.error('Something unexpected happened during the loading rows to ' + 'sample optimized table stage: %s', str(e)) + raise e + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) run()