diff --git a/gcp_variant_transforms/libs/bigquery_util.py b/gcp_variant_transforms/libs/bigquery_util.py index 480458392..055b95580 100644 --- a/gcp_variant_transforms/libs/bigquery_util.py +++ b/gcp_variant_transforms/libs/bigquery_util.py @@ -14,6 +14,7 @@ """Constants and simple utility functions related to BigQuery.""" +from concurrent.futures import TimeoutError import enum import exceptions import logging @@ -45,7 +46,6 @@ _TOTAL_BASE_PAIRS_SIG_DIGITS = 4 _PARTITION_SIZE_SIG_DIGITS = 1 -START_POSITION_COLUMN = 'start_position' _BQ_CREATE_PARTITIONED_TABLE_COMMAND = ( 'bq mk --table --range_partitioning=' '{PARTITION_COLUMN},0,{RANGE_END},{RANGE_INTERVAL} ' @@ -54,10 +54,28 @@ _BQ_CREATE_SAMPLE_INFO_TABLE_COMMAND = ( 'bq mk --table {FULL_TABLE_ID} {SCHEMA_FILE_PATH}') _BQ_DELETE_TABLE_COMMAND = 'bq rm -f -t {FULL_TABLE_ID}' +_BQ_EXTRACT_SCHEMA_COMMAND = ( + 'bq show --schema --format=prettyjson {FULL_TABLE_ID} > {SCHEMA_FILE_PATH}') _GCS_DELETE_FILES_COMMAND = 'gsutil -m rm -f -R {ROOT_PATH}' _BQ_NUM_RETRIES = 5 _MAX_NUM_CONCURRENT_BQ_LOAD_JOBS = 4 +_GET_COLUMN_NAMES_QUERY = ( + 'SELECT column_name ' + 'FROM `{PROJECT_ID}`.{DATASET_ID}.INFORMATION_SCHEMA.COLUMNS ' + 'WHERE table_name = "{TABLE_ID}"') +_GET_CALL_SUB_FIELDS_QUERY = ( + 'SELECT field_path ' + 'FROM `{PROJECT_ID}`.{DATASET_ID}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS ' + 'WHERE table_name = "{TABLE_ID}" AND column_name="{CALL_COLUMN}"') + +_MAIN_TABLE_ALIAS = 'main_table' +_CALL_TABLE_ALIAS = 'call_table' +_COLUMN_AS = '{TABLE_ALIAS}.{COL} AS `{COL_NAME}`' +_FLATTEN_CALL_QUERY = ( + 'SELECT {SELECT_COLUMNS} ' + 'FROM `{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}` as {MAIN_TABLE_ALIAS}, ' + 'UNNEST({CALL_COLUMN}) as {CALL_TABLE_ALIAS}') class ColumnKeyConstants(object): """Constants for column names in the BigQuery schema.""" @@ -75,6 +93,9 @@ class ColumnKeyConstants(object): CALLS_GENOTYPE = 'genotype' CALLS_PHASESET = 'phaseset' +CALL_SAMPLE_ID_COLUMN = (ColumnKeyConstants.CALLS + '_' + + 
ColumnKeyConstants.CALLS_SAMPLE_ID) + class TableFieldConstants(object): """Constants for field modes/types in the BigQuery schema.""" @@ -508,6 +529,201 @@ def create_sample_info_table(output_table_id): SCHEMA_FILE_PATH=SAMPLE_INFO_TABLE_SCHEMA_FILE_PATH) _run_table_creation_command(bq_command) +class FlattenCallColumn(object): + """Flattens call column to convert variant opt tables to sample opt tables.""" + + def __init__(self, base_table_id, suffixes): + # type: (str, List[str]) -> None + """Initialize `FlattenCallColumn` object. + + In preparation to convert variant lookup optimized tables to sample lookup + optimized tables, we initiate this class with the base table name of variant + opt table (set using --output_table flag) and the list of suffixes (which + are extracted from sharding config file). + + Args: + base_table_id: Base name of variant opt outputs (set by --output_table). + suffixes: List of suffixes (extracted from sharding config file). + """ + (self._project_id, + self._dataset_id, + self._base_table) = parse_table_reference(base_table_id) + assert suffixes + self._suffixes = suffixes[:] + + # We can use any of the input tables as source of schema, we use index 0 + self._schema_table_id = compose_table_name(self._base_table, + suffixes[0]) + self._column_names = [] + self._sub_fields = [] + self._client = bigquery.Client(project=self._project_id) + + def _run_query(self, query): + query_job = self._client.query(query) + num_retries = 0 + while True: + try: + iterator = query_job.result(timeout=300) + except TimeoutError as e: + logging.warning('Time out waiting for query: %s', query) + if num_retries < _BQ_NUM_RETRIES: + num_retries += 1 + time.sleep(90) + else: + raise e + else: + break + result = [] + for i in iterator: + result.append(str(i.values()[0])) + return result + + def _get_column_names(self): + if not self._column_names: + query = _GET_COLUMN_NAMES_QUERY.format(PROJECT_ID=self._project_id, + DATASET_ID=self._dataset_id, + 
TABLE_ID=self._schema_table_id) + self._column_names = self._run_query(query)[:] + assert self._column_names + return self._column_names + + def _get_call_sub_fields(self): + if not self._sub_fields: + query = _GET_CALL_SUB_FIELDS_QUERY.format( + PROJECT_ID=self._project_id, DATASET_ID=self._dataset_id, + TABLE_ID=self._schema_table_id, CALL_COLUMN=ColumnKeyConstants.CALLS) + # returned list is [call, call.name, call.genotype, call.phaseset, ...] + result = self._run_query(query)[1:] # Drop the first element + self._sub_fields = [sub_field.split('.')[1] for sub_field in result] + assert self._sub_fields + return self._sub_fields + + def _get_flatten_column_names(self): + column_names = self._get_column_names() + sub_fields = self._get_call_sub_fields() + select_list = [] + for column in column_names: + if column != ColumnKeyConstants.CALLS: + select_list.append( + _COLUMN_AS.format(TABLE_ALIAS=_MAIN_TABLE_ALIAS, COL=column, + COL_NAME=column)) + else: + for s_f in sub_fields: + select_list.append( + _COLUMN_AS.format(TABLE_ALIAS=_CALL_TABLE_ALIAS, COL=s_f, + COL_NAME=ColumnKeyConstants.CALLS + '_' + s_f)) + return ', '.join(select_list) + + def _copy_to_flatten_table(self, output_table_id, cp_query): + job_config = bigquery.QueryJobConfig(destination=output_table_id) + query_job = self._client.query(cp_query, job_config=job_config) + num_retries = 0 + while True: + try: + _ = query_job.result(timeout=600) + except TimeoutError as e: + logging.warning('Time out waiting for query: %s', cp_query) + if num_retries < _BQ_NUM_RETRIES: + num_retries += 1 + time.sleep(90) + else: + logging.error('Copy to table query failed: %s', output_table_id) + raise e + else: + break + logging.info('Copy to table query was successful: %s', output_table_id) + + def _create_temp_flatten_table_with_1_row(self): + temp_suffix = time.strftime('%Y%m%d_%H%M%S') + temp_table_id = '{}{}'.format(self._schema_table_id, temp_suffix) + full_output_table_id = '{}.{}.{}'.format( + self._project_id, 
self._dataset_id, temp_table_id) + + select_columns = self._get_flatten_column_names() + cp_query = _FLATTEN_CALL_QUERY.format(SELECT_COLUMNS=select_columns, + PROJECT_ID=self._project_id, + DATASET_ID=self._dataset_id, + TABLE_ID=self._schema_table_id, + MAIN_TABLE_ALIAS=_MAIN_TABLE_ALIAS, + CALL_COLUMN=ColumnKeyConstants.CALLS, + CALL_TABLE_ALIAS=_CALL_TABLE_ALIAS) + cp_query += ' LIMIT 1' # We need this table only to extract its schema. + self._copy_to_flatten_table(full_output_table_id, cp_query) + logging.info('A new table with 1 row was created: %s', full_output_table_id) + logging.info('This table is used to extract the schema of flatten table.') + return temp_table_id + + def get_flatten_table_schema(self, schema_file_path): + # type: (str) -> bool + """Write the flatten table's schema to the given json file. + + This method basically performs the following tasks: + * Composes a 'flattening query' based on _schema_table_id table's schema. + * Runs the 'flattening query' to read 1 row and writes it to a temp table. + * Extracts the schema of the temp table using _BQ_EXTRACT_SCHEMA_COMMAND. + + Args: + schema_file_path: The json schema will be written to this file. + + Returns: + A bool value indicating if the schema was successfully extracted. 
+ """ + temp_table_id = self._create_temp_flatten_table_with_1_row() + full_table_id = '{}:{}.{}'.format( + self._project_id, self._dataset_id, temp_table_id) + bq_command = _BQ_EXTRACT_SCHEMA_COMMAND.format( + FULL_TABLE_ID=full_table_id, + SCHEMA_FILE_PATH=schema_file_path) + result = os.system(bq_command) + if result != 0: + logging.error('Failed to extract flatten table schema using "%s" command', + bq_command) + else: + logging.info('Successfully extracted the schema of flatten table.') + if delete_table(full_table_id) == 0: + logging.info('Successfully deleted temporary table: %s', full_table_id) + else: + logging.error('Was not able to delete temporary table: %s', full_table_id) + return result == 0 + + def copy_to_flatten_table(self, output_base_table_id): + # type: (str) -> None + """Copies data from variant lookup optimized tables to sample lookup tables. + + Copies rows from _base_table_id__* to output_base_table_id__* for each value + in _suffixes. Here we assume destination tables are already created and are + partitioned based on call_sample_id column. The copying process is done via + a flattening query similar to the one used in get_flatten_table_schema(). + + Note that if source tables have repeated sample_ids then output table will + have more rows than input table. Essentially: + Number of output rows = Number of input rows * Number of repeated sample_ids + + Args: + output_base_table_id: Base table name of output tables. + """ + # Here we assume all output_table_base + suffixes[:] are already created. 
+ (output_project_id, + output_dataset_id, + output_base_table) = parse_table_reference(output_base_table_id) + select_columns = self._get_flatten_column_names() + for suffix in self._suffixes: + input_table_id = compose_table_name(self._base_table, suffix) + output_table_id = compose_table_name(output_base_table, suffix) + + full_output_table_id = '{}.{}.{}'.format( + output_project_id, output_dataset_id, output_table_id) + cp_query = _FLATTEN_CALL_QUERY.format( + SELECT_COLUMNS=select_columns, PROJECT_ID=self._project_id, + DATASET_ID=self._dataset_id, TABLE_ID=input_table_id, + MAIN_TABLE_ALIAS=_MAIN_TABLE_ALIAS, + CALL_COLUMN=ColumnKeyConstants.CALLS, + CALL_TABLE_ALIAS=_CALL_TABLE_ALIAS) + + self._copy_to_flatten_table(full_output_table_id, cp_query) + logging.info('Flatten table is fully loaded: %s', full_output_table_id) + + def create_output_table(full_table_id, # type: str partition_column, # type: str range_end, # type: int diff --git a/gcp_variant_transforms/libs/bigquery_util_test.py b/gcp_variant_transforms/libs/bigquery_util_test.py index f598d2175..ceda9730c 100644 --- a/gcp_variant_transforms/libs/bigquery_util_test.py +++ b/gcp_variant_transforms/libs/bigquery_util_test.py @@ -516,3 +516,50 @@ def test_calculate_optimal_range_interval_large(self): bigquery_util.calculate_optimal_range_interval(large_range_end)) self.assertEqual(expected_interval, range_interval) self.assertEqual(expected_end, range_end_enlarged) + + +class FlattenCallColumnTest(unittest.TestCase): + """Test cases for class `FlattenCallColumn`.""" + + def setUp(self): + # We never query this table for running the following test, however, the + # mock values are based on this table's schema. In other words: + # mock_columns.return_value = self._flatter._get_column_names() + # mock_sub_fields.return_value = self._flatter._get_call_sub_fields() + input_base_table = ('gcp-variant-transforms-test:' + 'bq_to_vcf_integration_tests.' 
+ 'merge_option_move_to_calls') + self._flatter = bigquery_util.FlattenCallColumn(input_base_table, ['chr20']) + + @mock.patch('gcp_variant_transforms.libs.bigquery_util_test.bigquery_util.' + 'FlattenCallColumn._get_column_names') + @mock.patch('gcp_variant_transforms.libs.bigquery_util_test.bigquery_util.' + 'FlattenCallColumn._get_call_sub_fields') + def test_get_flatten_column_names(self, mock_sub_fields, mock_columns): + mock_columns.return_value = ( + ['reference_name', 'start_position', 'end_position', 'reference_bases', + 'alternate_bases', 'names', 'quality', 'filter', 'call', 'NS', 'DP', + 'AA', 'DB', 'H2']) + mock_sub_fields.return_value = ( + ['sample_id', 'genotype', 'phaseset', 'DP', 'GQ', 'HQ']) + expected_select = ( + 'main_table.reference_name AS `reference_name`, ' + 'main_table.start_position AS `start_position`, ' + 'main_table.end_position AS `end_position`, ' + 'main_table.reference_bases AS `reference_bases`, ' + 'main_table.alternate_bases AS `alternate_bases`, ' + 'main_table.names AS `names`, ' + 'main_table.quality AS `quality`, ' + 'main_table.filter AS `filter`, ' + 'call_table.sample_id AS `call_sample_id`, ' + 'call_table.genotype AS `call_genotype`, ' + 'call_table.phaseset AS `call_phaseset`, ' + 'call_table.DP AS `call_DP`, ' + 'call_table.GQ AS `call_GQ`, ' + 'call_table.HQ AS `call_HQ`, ' + 'main_table.NS AS `NS`, ' + 'main_table.DP AS `DP`, ' + 'main_table.AA AS `AA`, ' + 'main_table.DB AS `DB`, ' + 'main_table.H2 AS `H2`') + self.assertEqual(expected_select, self._flatter._get_flatten_column_names()) diff --git a/gcp_variant_transforms/options/variant_transform_options.py b/gcp_variant_transforms/options/variant_transform_options.py index e1a7c8222..b60929e71 100644 --- a/gcp_variant_transforms/options/variant_transform_options.py +++ b/gcp_variant_transforms/options/variant_transform_options.py @@ -124,9 +124,31 @@ class BigQueryWriteOptions(VariantTransformsOptions): def add_arguments(self, parser): # type: 
(argparse.ArgumentParser) -> None - parser.add_argument('--output_table', - default='', - help='BigQuery table to store the results.') + parser.add_argument( + '--output_table', + default='', + help=('Base name of the BigQuery tables which will store the results. ' + 'Note that sharded tables will be named as following: ' + ' * `output_table`__chr1 ' + ' * `output_table`__chr2 ' + ' * ... ' + ' * `output_table`__residual ' + 'where "chr1", "chr2", ..., and "residual" suffixes correspond ' + 'to the value of `table_name_suffix` in the sharding config file ' + '(see --sharding_config_path).')) + + parser.add_argument( + '--sample_lookup_optimized_output_table', + default='', + help=('In addition to the default output tables (which are optimized ' + 'for variant lookup queries), you can store a second copy of ' + 'your data in BigQuery tables that are optimized for sample ' + 'lookup queries using this flag.' + 'Note that setting this flag will *at least* double your ' + 'BigQuery storage costs. If your input VCF files are joint ' + 'genotyped (say with n sample) then sample lookup tables will ' + 'have n * the number of rows of their corresponding variant ' + 'lookup table.')) parser.add_argument( '--output_avro_path', default='', @@ -206,54 +228,70 @@ def validate(self, parsed_args, client=None): not parsed_args.sharding_config_path.strip()): raise ValueError( '--sharding_config_path must point to a valid config file.') - # Ensuring (not) existence of output tables is aligned with --append value. 
- if parsed_args.output_table: - if (parsed_args.output_table != - bigquery_util.get_table_base_name(parsed_args.output_table)): - raise ValueError(('Output table cannot contain "{}" we reserve this ' + + if not client: + credentials = GoogleCredentials.get_application_default().create_scoped( + ['https://www.googleapis.com/auth/bigquery']) + client = bigquery.BigqueryV2(credentials=credentials) + if not parsed_args.output_table: + raise ValueError('--output_table must have a value.') + self._validate_output_tables( + client, parsed_args.output_table, + parsed_args.sharding_config_path, parsed_args.append) + + if parsed_args.sample_lookup_optimized_output_table: + if (parsed_args.output_table == + parsed_args.sample_lookup_optimized_output_table): + raise ValueError('sample_lookup_optimized_output_table cannot be the ' + 'same as output_table.') + self._validate_output_tables( + client, parsed_args.sample_lookup_optimized_output_table, + parsed_args.sharding_config_path, parsed_args.append) + + def _validate_output_tables(self, client, + output_table_base_name, + sharding_config_path, append): + if (output_table_base_name != + bigquery_util.get_table_base_name(output_table_base_name)): + raise ValueError(('Output table cannot contain "{}". we reserve this ' + 'string to mark sharded output tables.').format( + bigquery_util.TABLE_SUFFIX_SEPARATOR)) + + project_id, dataset_id, table_id = bigquery_util.parse_table_reference( + output_table_base_name) + bigquery_util.raise_error_if_dataset_not_exists(client, project_id, + dataset_id) + all_output_tables = [] + all_output_tables.append( + bigquery_util.compose_table_name(table_id, SAMPLE_INFO_TABLE_SUFFIX)) + sharding = variant_sharding.VariantSharding(sharding_config_path) + num_shards = sharding.get_num_shards() + # In case there is no residual in config we will ignore the last shard. 
+ if not sharding.should_keep_shard(sharding.get_residual_index()): + num_shards -= 1 + for i in range(num_shards): + table_suffix = sharding.get_output_table_suffix(i) + if table_suffix != bigquery_util.get_table_base_name(table_suffix): + raise ValueError(('Table suffix cannot contain "{}" we reserve this ' 'string to mark sharded output tables.').format( bigquery_util.TABLE_SUFFIX_SEPARATOR)) - if not client: - credentials = GoogleCredentials.get_application_default().create_scoped( - ['https://www.googleapis.com/auth/bigquery']) - client = bigquery.BigqueryV2(credentials=credentials) - - project_id, dataset_id, table_id = bigquery_util.parse_table_reference( - parsed_args.output_table) - bigquery_util.raise_error_if_dataset_not_exists(client, project_id, - dataset_id) - all_output_tables = [] - all_output_tables.append( - bigquery_util.compose_table_name(table_id, SAMPLE_INFO_TABLE_SUFFIX)) - sharding = variant_sharding.VariantSharding( - parsed_args.sharding_config_path) - num_shards = sharding.get_num_shards() - # In case there is no residual in config we will ignore the last shard. 
- if not sharding.should_keep_shard(sharding.get_residual_index()): - num_shards -= 1 - for i in range(num_shards): - table_suffix = sharding.get_output_table_suffix(i) - if table_suffix != bigquery_util.get_table_base_name(table_suffix): - raise ValueError(('Table suffix cannot contain "{}" we reserve this ' - 'string to mark sharded output tables.').format( - bigquery_util.TABLE_SUFFIX_SEPARATOR)) - all_output_tables.append(bigquery_util.compose_table_name(table_id, - table_suffix)) - - for output_table in all_output_tables: - if parsed_args.append: - if not bigquery_util.table_exist(client, project_id, - dataset_id, output_table): - raise ValueError( - 'Table {}:{}.{} does not exist, cannot append to it.'.format( - project_id, dataset_id, output_table)) - else: - if bigquery_util.table_exist(client, project_id, - dataset_id, output_table): - raise ValueError( - ('Table {}:{}.{} already exists, cannot overwrite it. Please ' - 'set `--append True` if you want to append to it.').format( - project_id, dataset_id, output_table)) + all_output_tables.append(bigquery_util.compose_table_name(table_id, + table_suffix)) + + for output_table in all_output_tables: + if append: + if not bigquery_util.table_exist(client, project_id, + dataset_id, output_table): + raise ValueError( + 'Table {}:{}.{} does not exist, cannot append to it.'.format( + project_id, dataset_id, output_table)) + else: + if bigquery_util.table_exist(client, project_id, + dataset_id, output_table): + raise ValueError( + ('Table {}:{}.{} already exists, cannot overwrite it. 
Please ' + 'set `--append True` if you want to append to it.').format( + project_id, dataset_id, output_table)) class AnnotationOptions(VariantTransformsOptions): diff --git a/gcp_variant_transforms/testing/data/schema/flatten_merge_option_move_to_calls.json b/gcp_variant_transforms/testing/data/schema/flatten_merge_option_move_to_calls.json new file mode 100644 index 000000000..15443f8f9 --- /dev/null +++ b/gcp_variant_transforms/testing/data/schema/flatten_merge_option_move_to_calls.json @@ -0,0 +1,109 @@ +[ + { + "mode": "NULLABLE", + "name": "reference_name", + "type": "STRING" + }, + { + "mode": "NULLABLE", + "name": "start_position", + "type": "INTEGER" + }, + { + "mode": "NULLABLE", + "name": "end_position", + "type": "INTEGER" + }, + { + "mode": "NULLABLE", + "name": "reference_bases", + "type": "STRING" + }, + { + "fields": [ + { + "mode": "NULLABLE", + "name": "alt", + "type": "STRING" + }, + { + "mode": "NULLABLE", + "name": "AF", + "type": "FLOAT" + } + ], + "mode": "REPEATED", + "name": "alternate_bases", + "type": "RECORD" + }, + { + "mode": "REPEATED", + "name": "names", + "type": "STRING" + }, + { + "mode": "NULLABLE", + "name": "quality", + "type": "FLOAT" + }, + { + "mode": "REPEATED", + "name": "filter", + "type": "STRING" + }, + { + "mode": "NULLABLE", + "name": "call_sample_id", + "type": "INTEGER" + }, + { + "mode": "REPEATED", + "name": "call_genotype", + "type": "INTEGER" + }, + { + "mode": "NULLABLE", + "name": "call_phaseset", + "type": "STRING" + }, + { + "mode": "NULLABLE", + "name": "call_DP", + "type": "INTEGER" + }, + { + "mode": "NULLABLE", + "name": "call_GQ", + "type": "INTEGER" + }, + { + "mode": "REPEATED", + "name": "call_HQ", + "type": "INTEGER" + }, + { + "mode": "NULLABLE", + "name": "NS", + "type": "INTEGER" + }, + { + "mode": "NULLABLE", + "name": "DP", + "type": "INTEGER" + }, + { + "mode": "NULLABLE", + "name": "AA", + "type": "STRING" + }, + { + "mode": "NULLABLE", + "name": "DB", + "type": "BOOLEAN" + }, + { + "mode": 
"NULLABLE", + "name": "H2", + "type": "BOOLEAN" + } +] diff --git a/gcp_variant_transforms/vcf_to_bq.py b/gcp_variant_transforms/vcf_to_bq.py index ef993a900..970b2ac4a 100644 --- a/gcp_variant_transforms/vcf_to_bq.py +++ b/gcp_variant_transforms/vcf_to_bq.py @@ -536,7 +536,7 @@ def run(argv=None): load_avro = bigquery_util.LoadAvro(avro_root_path, known_args.output_table, suffixes, not known_args.append) - _ = load_avro.start_loading() + not_empty_variant_suffixes = load_avro.start_loading() except Exception as e: logging.error('Something unexpected happened during the loading of AVRO ' 'files to BigQuery: %s', str(e)) @@ -559,6 +559,33 @@ def run(argv=None): 'failed.', avro_root_path) + if known_args.sample_lookup_optimized_output_table: + flatten_call_column = bigquery_util.FlattenCallColumn( + known_args.output_table, not_empty_variant_suffixes) + try: + flatten_schema_file = tempfile.mkstemp(suffix=_BQ_SCHEMA_FILE_SUFFIX)[1] + if not flatten_call_column.get_flatten_table_schema(flatten_schema_file): + raise ValueError('Failed to extract schema of flatten table') + # Create output flatten tables if needed + if not known_args.append: + for suffix in not_empty_variant_suffixes: + output_table_id = bigquery_util.compose_table_name( + known_args.sample_lookup_optimized_output_table, suffix) + bigquery_util.create_output_table(output_table_id, + bigquery_util.CALL_SAMPLE_ID_COLUMN, + bigquery_util.MAX_RANGE_END, + flatten_schema_file) + logging.info('Sample lookup optimized table %s was created.', + output_table_id) + # Copy to flatten sample lookup tables from the variant lookup tables. 
+ flatten_call_column.copy_to_flatten_table( + known_args.sample_lookup_optimized_output_table) + logging.info('All sample lookup optimized tables are fully loaded.') + except Exception as e: + logging.error('Something unexpected happened during the loading rows to ' + 'sample optimized table stage: %s', str(e)) + raise e + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) run()