#299 - Sample list ease of use for cohort extracts (#7272)
* work on issue 299
* PR comments
* update dockstore/image
kcibul authored May 26, 2021
1 parent 13aa5cf commit f09214c
Showing 3 changed files with 98 additions and 54 deletions.
7 changes: 3 additions & 4 deletions scripts/variantstore/wdl/GvsExtractCallset.wdl
@@ -13,8 +13,7 @@ workflow GvsExtractCallset {
File reference_index
File reference_dict

String fq_samples_to_extract_table
String fq_cohort_extract_table
String fq_cohort_extract_table_prefix
String query_project = data_project

String fq_filter_set_info_table = "~{data_project}.~{default_dataset}.filter_set_info"
@@ -51,9 +50,9 @@ workflow GvsExtractCallset {
reference = reference,
reference_index = reference_index,
reference_dict = reference_dict,
fq_samples_to_extract_table = fq_samples_to_extract_table,
fq_samples_to_extract_table = "~{fq_cohort_extract_table_prefix}__SAMPLES",
intervals = SplitIntervals.interval_files[i],
fq_cohort_extract_table = fq_cohort_extract_table,
fq_cohort_extract_table = "~{fq_cohort_extract_table_prefix}__DATA",
read_project_id = query_project,
do_not_filter_override = do_not_filter_override,
fq_filter_set_info_table = fq_filter_set_info_table,
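The workflow now takes a single fq_cohort_extract_table_prefix input and derives both concrete table names itself. A minimal sketch of that expansion in Python, using a hypothetical fully qualified prefix:

# Hypothetical value; in practice this is the fully qualified
# "project.dataset.cohort_name" prefix passed to the workflow.
fq_cohort_extract_table_prefix = "my-project.my_dataset.my_fantastic_cohort"

# The extract task reads its sample list and cohort data from the two derived tables,
# mirroring the "~{...}__SAMPLES" / "~{...}__DATA" interpolation in the WDL above.
fq_samples_to_extract_table = f"{fq_cohort_extract_table_prefix}__SAMPLES"
fq_cohort_extract_table = f"{fq_cohort_extract_table_prefix}__DATA"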
18 changes: 9 additions & 9 deletions scripts/variantstore/wdl/GvsPrepareCallset.wdl
@@ -4,7 +4,8 @@ workflow GvsPrepareCallset {
input {
String data_project
String default_dataset
String destination_cohort_table_name
String destination_cohort_table_prefix
File sample_names_to_extract

# inputs with defaults
String query_project = data_project
@@ -13,7 +14,6 @@ workflow GvsPrepareCallset {
String destination_dataset = default_dataset

String fq_petvet_dataset = "~{data_project}.~{default_dataset}"
String fq_cohort_sample_table = "~{data_project}.~{default_dataset}.sample_info"
String fq_sample_mapping_table = "~{data_project}.~{default_dataset}.sample_info"
String fq_temp_table_dataset = "~{destination_project}.temp_tables"
String fq_destination_dataset = "~{destination_project}.~{destination_dataset}"
@@ -23,15 +23,15 @@
String? docker
}

String docker_final = select_first([docker, "us.gcr.io/broad-dsde-methods/variantstore:ah_var_store_20210507"])
String docker_final = select_first([docker, "us.gcr.io/broad-dsde-methods/variantstore:ah_var_store_20210526"])

call PrepareCallsetTask {
input:
destination_cohort_table_name = destination_cohort_table_name,
destination_cohort_table_prefix = destination_cohort_table_prefix,
sample_names_to_extract = sample_names_to_extract,
query_project = query_project,
query_labels = query_labels,
fq_petvet_dataset = fq_petvet_dataset,
fq_cohort_sample_table = fq_cohort_sample_table,
fq_sample_mapping_table = fq_sample_mapping_table,
fq_temp_table_dataset = fq_temp_table_dataset,
fq_destination_dataset = fq_destination_dataset,
@@ -49,12 +49,12 @@ task PrepareCallsetTask {
}

input {
String destination_cohort_table_name
String destination_cohort_table_prefix
File sample_names_to_extract
String query_project
Array[String]? query_labels

String fq_petvet_dataset
String fq_cohort_sample_table
String fq_sample_mapping_table
String fq_temp_table_dataset
String fq_destination_dataset
@@ -73,8 +73,8 @@ task PrepareCallsetTask {
--fq_petvet_dataset ~{fq_petvet_dataset} \
--fq_temp_table_dataset ~{fq_temp_table_dataset} \
--fq_destination_dataset ~{fq_destination_dataset} \
--destination_table ~{destination_cohort_table_name} \
--fq_cohort_sample_names ~{fq_cohort_sample_table} \
--destination_cohort_table_prefix ~{destination_cohort_table_prefix} \
--sample_names_to_extract ~{sample_names_to_extract} \
--query_project ~{query_project} \
~{sep=" " query_label_args} \
--fq_sample_mapping_table ~{fq_sample_mapping_table} \
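With these changes PrepareCallsetTask passes a sample-names file and a table prefix instead of a pre-built cohort table. A rough sketch (hypothetical interpolated values, limited to the flags visible in the diff) of the argument list the task now hands to create_cohort_extract_data_table.py:

# Hypothetical values standing in for the WDL interpolations above.
args = [
    "--fq_petvet_dataset", "my-project.my_dataset",
    "--fq_temp_table_dataset", "my-project.temp_tables",
    "--fq_destination_dataset", "my-project.my_dataset",
    "--destination_cohort_table_prefix", "my_fantastic_cohort",
    "--sample_names_to_extract", "sample_names.txt",
    "--query_project", "my-project",
    "--fq_sample_mapping_table", "my-project.my_dataset.sample_info",
]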
127 changes: 86 additions & 41 deletions scripts/variantstore/wdl/extract/create_cohort_extract_data_table.py
@@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
import sys
import uuid
import time
import datetime

from concurrent.futures import ThreadPoolExecutor, as_completed
from google.cloud import bigquery
from google.cloud.bigquery.job import QueryJobConfig
from google.oauth2 import service_account
@@ -20,7 +19,6 @@
VET_TABLE_PREFIX = "vet_"
SAMPLES_PER_PARTITION = 4000

TEMP_TABLE_TTL = " OPTIONS( expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 HOUR)) "
FINAL_TABLE_TTL = ""
#FINAL_TABLE_TTL = " OPTIONS( expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 72 HOUR)) "

@@ -34,6 +32,7 @@
VET_DISTINCT_POS_TABLE = f"{output_table_prefix}_vet_distinct_pos"
PET_NEW_TABLE = f"{output_table_prefix}_pet_new"
VET_NEW_TABLE = f"{output_table_prefix}_vet_new"
EXTRACT_SAMPLE_TABLE = f"{output_table_prefix}_sample_names"

def utf8len(s):
return len(s.encode('utf-8'))
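The new EXTRACT_SAMPLE_TABLE constant follows the same per-run naming scheme as the other intermediate tables. For illustration, with a hypothetical output_table_prefix:

# Hypothetical per-run prefix; the real value is generated elsewhere in the script.
output_table_prefix = "ce_3f9a12"
EXTRACT_SAMPLE_TABLE = f"{output_table_prefix}_sample_names"  # "ce_3f9a12_sample_names"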
@@ -70,7 +69,6 @@ def execute_with_retry(label, sql):
print(f"COMPLETED ({time.time() - start} s, {3-len(retry_delay)} retries) - {label}")
return results
except Exception as err:

# if there are no retries left... raise
if (len(retry_delay) == 0):
raise err
@@ -85,19 +83,48 @@

return { 'start': (i-1)*SAMPLES_PER_PARTITION + 1, 'end': i*SAMPLES_PER_PARTITION }

def get_samples_for_partition(cohort, i):
return [ s for s in cohort if s >= get_partition_range(i)['start'] and s <= get_partition_range(i)['end'] ]
def get_samples_for_partition(sample_ids, i):
return [ s for s in sample_ids if s >= get_partition_range(i)['start'] and s <= get_partition_range(i)['end'] ]

def split_lists(samples, n):
return [samples[i * n:(i + 1) * n] for i in range((len(samples) + n - 1) // n )]
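A small worked example of the partition helpers, assuming the module's SAMPLES_PER_PARTITION of 4000:

# get_partition_range maps a 1-based partition index to its sample_id range;
# partition i corresponds to the pet_{i:03}/vet_{i:03} tables queried below.
get_partition_range(1)  # {'start': 1, 'end': 4000}
get_partition_range(2)  # {'start': 4001, 'end': 8000}

# get_samples_for_partition keeps only the ids that land in that range.
get_samples_for_partition([12, 3999, 4001], 1)  # [12, 3999]
get_samples_for_partition([12, 3999, 4001], 2)  # [4001]

# split_lists chunks a list into pieces of at most n elements.
split_lists([1, 2, 3, 4, 5], 2)  # [[1, 2], [3, 4], [5]]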

def get_all_samples(fq_cohort_sample_names, fq_sample_mapping_table):
sql = f"select m.sample_id from `{fq_cohort_sample_names}` c JOIN `{fq_sample_mapping_table}` m ON (m.sample_name = c.sample_name)"
def load_sample_names(sample_names_to_extract, fq_temp_table_dataset):
schema = [ bigquery.SchemaField("sample_name", "STRING", mode="REQUIRED") ]
fq_sample_table = f"{fq_temp_table_dataset}.{EXTRACT_SAMPLE_TABLE}"

job_labels = client._default_query_job_config.labels
job_labels["gvs_query_name"] = "load-sample-names"

job_config = bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.CSV, skip_leading_rows=0, schema=schema, labels=job_labels)

with open(sample_names_to_extract, "rb") as source_file:
job = client.load_table_from_file(source_file, fq_sample_table, job_config=job_config)

job.result() # Waits for the job to complete.

# setting the TTL needs to be done as a second API call
table = bigquery.Table(fq_sample_table, schema=schema)
expiration = datetime.datetime.utcnow() + datetime.timedelta(hours=TEMP_TABLE_TTL_HOURS)
table.expires = expiration
table = client.update_table(table, ["expires"])

return fq_sample_table
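A minimal usage sketch (hypothetical paths and dataset names; assumes the module-level client has already been initialized, as it is in make_extract_table):

# "sample_names.txt" is assumed to hold one sample name per line with no header
# (the load job uses skip_leading_rows=0).
fq_sample_name_table = load_sample_names("sample_names.txt", "my-project.temp_tables")

# The returned name points at a single-column (sample_name STRING REQUIRED) table,
# e.g. "my-project.temp_tables.<output_table_prefix>_sample_names", which expires
# TEMP_TABLE_TTL_HOURS after creation.
print(fq_sample_name_table)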

def get_all_sample_ids(fq_destination_table_samples):
sql = f"select sample_id from `{fq_destination_table_samples}`"

results = execute_with_retry("read cohort sample table", sql)
sample_ids = [row.sample_id for row in list(results)]
sample_ids.sort()
return sample_ids

def create_extract_samples_table(fq_destination_table_samples, fq_sample_name_table, fq_sample_mapping_table):
sql = f"CREATE OR REPLACE TABLE `{fq_destination_table_samples}` AS (" \
f"SELECT m.sample_id, m.sample_name FROM `{fq_sample_name_table}` s JOIN `{fq_sample_mapping_table}` m ON (s.sample_name = m.sample_name) )"

results = execute_with_retry("read cohort table", sql)
cohort = [row.sample_id for row in list(results)]
cohort.sort()
return cohort
results = execute_with_retry("create extract sample table", sql)
return results
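A sketch of how these two helpers combine (hypothetical table names), mirroring the driver logic added to make_extract_table further down:

# Join the loaded sample names to the sample_info mapping table to materialize the
# cohort's __SAMPLES table, then read the numeric ids back to drive the extract.
fq_destination_table_samples = "my-project.my_dataset.my_fantastic_cohort__SAMPLES"
fq_sample_mapping_table = "my-project.my_dataset.sample_info"

create_extract_samples_table(fq_destination_table_samples,
                             fq_sample_name_table,  # from load_sample_names above
                             fq_sample_mapping_table)
sample_ids = get_all_sample_ids(fq_destination_table_samples)  # sorted list of sample_id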

def get_table_count(fq_pet_vet_dataset):
sql = f"SELECT MAX(CAST(SPLIT(table_name, '_')[OFFSET(1)] AS INT64)) max_table_number " \
@@ -106,15 +133,15 @@ def get_table_count(fq_pet_vet_dataset):
results = execute_with_retry("get max table", sql)
return int([row.max_table_number for row in list(results)][0])

def make_new_vet_union_all(fq_pet_vet_dataset, fq_temp_table_dataset, cohort):
def make_new_vet_union_all(fq_pet_vet_dataset, fq_temp_table_dataset, sample_ids):
def get_subselect(fq_vet_table, samples, id):
sample_stanza = ','.join([str(s) for s in samples])
sql = f" q_{id} AS (SELECT location, sample_id, ref, alt, call_GT, call_GQ, call_pl, QUALapprox, AS_QUALapprox from `{fq_vet_table}` WHERE sample_id IN ({sample_stanza})), "
return sql

subs = {}
for i in range(1, PET_VET_TABLE_COUNT+1):
partition_samples = get_samples_for_partition(cohort, i)
partition_samples = get_samples_for_partition(sample_ids, i)

# KCIBUL -- grr, should be fixed width
fq_vet_table = f"{fq_pet_vet_dataset}.{VET_TABLE_PREFIX}{i:03}"
@@ -163,7 +190,7 @@ def create_position_table(fq_temp_table_dataset, min_variant_samples):
JOB_IDS.add((f"create positions table {dest}", create_vet_distinct_pos_query.job_id))
return

def make_new_pet_union_all(fq_pet_vet_dataset, fq_temp_table_dataset, cohort):
def make_new_pet_union_all(fq_pet_vet_dataset, fq_temp_table_dataset, sample_ids):
def get_pet_subselect(fq_pet_table, samples, id):
sample_stanza = ','.join([str(s) for s in samples])
sql = f" q_{id} AS (SELECT p.location, p.sample_id, p.state from {fq_pet_table} p " \
@@ -172,7 +199,7 @@ def get_pet_subselect(fq_pet_table, samples, id):

subs = {}
for i in range(1, PET_VET_TABLE_COUNT+1):
partition_samples = get_samples_for_partition(cohort, i)
partition_samples = get_samples_for_partition(sample_ids, i)

# KCIBUL -- grr, should be fixed width
fq_pet_table = f"{fq_pet_vet_dataset}.{PET_TABLE_PREFIX}{i:03}"
@@ -195,10 +222,9 @@ def get_pet_subselect(fq_pet_table, samples, id):
return results


def populate_final_extract_table(fq_temp_table_dataset, fq_destination_dataset, destination_table, fq_sample_mapping_table):
dest = f"{fq_destination_dataset}.{destination_table}"
def populate_final_extract_table(fq_temp_table_dataset, fq_destination_table_data, fq_sample_mapping_table):
sql = f"""
CREATE OR REPLACE TABLE `{dest}`
CREATE OR REPLACE TABLE `{fq_destination_table_data}`
PARTITION BY RANGE_BUCKET(location, GENERATE_ARRAY(0, 26000000000000, 6500000000))
CLUSTER BY location
{FINAL_TABLE_TTL}
@@ -230,43 +256,48 @@ def populate_final_extract_table(fq_temp_table_dataset, fq_destination_dataset,
cohort_extract_final_query_job = client.query(sql, job_config=job_config)

cohort_extract_final_query_job.result()
JOB_IDS.add((f"insert final cohort table {dest}", cohort_extract_final_query_job.job_id))
JOB_IDS.add((f"insert final cohort table {fq_destination_table_data}", cohort_extract_final_query_job.job_id))
return

def make_extract_table(fq_pet_vet_dataset,
max_tables,
sample_names_to_extract,
fq_cohort_sample_names,
query_project,
query_labels,
fq_temp_table_dataset,
fq_destination_dataset,
destination_table,
destination_table_prefix,
min_variant_samples,
fq_sample_mapping_table,
sa_key_path,
temp_table_ttl_hours
):
try:
fq_destination_table_data = f"{fq_destination_dataset}.{destination_table_prefix}__DATA"
fq_destination_table_samples = f"{fq_destination_dataset}.{destination_table_prefix}__SAMPLES"

global client
# this is where a set of labels are being created for the cohort extract
query_labels_map = {}
query_labels_map["id"]= {output_table_prefix}
query_labels_map["gvs_tool_name"]= f"gvs_prepare_callset"
query_labels_map["id"]= output_table_prefix
query_labels_map["gvs_tool_name"]= "gvs_prepare_callset"

# query_labels is string that looks like 'key1=val1, key2=val2'
if len(query_labels) != 0:
if query_labels is not None and len(query_labels) != 0:
for query_label in query_labels:
kv = query_label.split("=", 2)
key = kv[0].strip().lower()
value = kv[1].strip().lower()
query_labels_map[key] = value

if not (bool(re.match(r"[a-z0-9_-]+$", key)) & bool(re.match(r"[a-z0-9_-]+$", value))):
raise ValueError(f"label key or value did not pass validation--format should be 'key1=val1, key2=val2'")

if not (bool(re.match(r"[a-z0-9_-]+$", key)) & bool(re.match(r"[a-z0-9_-]+$", value))):
raise ValueError(f"label key or value did not pass validation--format should be 'key1=val1, key2=val2'")
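For reference, a worked example (assumed inputs) of what the parsing and validation above accept:

# Each --query_labels entry is "key=value"; keys and values are lower-cased,
# stripped, and must match [a-z0-9_-]+.
query_labels = ["Project=GVS_Extract", "Team=Variants "]
labels = {}
for query_label in query_labels:
    kv = query_label.split("=", 2)
    labels[kv[0].strip().lower()] = kv[1].strip().lower()
# labels == {'project': 'gvs_extract', 'team': 'variants'}
# An entry like "Owner=Jane Doe" would fail validation because of the embedded space.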
#Default QueryJobConfig will be merged into job configs passed in
#but if a specific default config is being updated (eg labels), new config must be added
#to the client._default_query_job_config that already exists
default_config = QueryJobConfig(labels=query_labels_map, priority="INTERACTIVE", use_query_cache=False )
default_config = QueryJobConfig(labels=query_labels_map, priority="INTERACTIVE", use_query_cache=False)

if sa_key_path:
credentials = service_account.Credentials.from_service_account_file(
@@ -285,55 +316,69 @@ def make_extract_table(fq_pet_vet_dataset,
global PET_VET_TABLE_COUNT
PET_VET_TABLE_COUNT = max_tables

global TEMP_TABLE_TTL_HOURS
TEMP_TABLE_TTL_HOURS = temp_table_ttl_hours

global TEMP_TABLE_TTL
TEMP_TABLE_TTL = f" OPTIONS( expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL {temp_table_ttl_hours} HOUR)) "
TEMP_TABLE_TTL = f" OPTIONS( expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL {TEMP_TABLE_TTL_HOURS} HOUR)) "

print(f"Using {PET_VET_TABLE_COUNT} PET tables in {fq_pet_vet_dataset}...")

cohort = get_all_samples(fq_cohort_sample_names, fq_sample_mapping_table)
print(f"Discovered {len(cohort)} samples in {fq_cohort_sample_names}...")
# if we have a file of sample names, load it into a temporary table
if (sample_names_to_extract):
fq_sample_name_table = load_sample_names(sample_names_to_extract, fq_temp_table_dataset)
else:
fq_sample_name_table = fq_cohort_sample_names

# At this point one way or the other we have a table of sample names in BQ,
# join it to the sample_info table to drive the extract
create_extract_samples_table(fq_destination_table_samples, fq_sample_name_table, fq_sample_mapping_table)

make_new_vet_union_all(fq_pet_vet_dataset, fq_temp_table_dataset, cohort)
# pull the sample ids back down
sample_ids = get_all_sample_ids(fq_destination_table_samples)
print(f"Discovered {len(sample_ids)} samples in {fq_destination_table_samples}...")

make_new_vet_union_all(fq_pet_vet_dataset, fq_temp_table_dataset, sample_ids)

create_position_table(fq_temp_table_dataset, min_variant_samples)
make_new_pet_union_all(fq_pet_vet_dataset, fq_temp_table_dataset, cohort)
populate_final_extract_table(fq_temp_table_dataset, fq_destination_dataset, destination_table, fq_sample_mapping_table)
make_new_pet_union_all(fq_pet_vet_dataset, fq_temp_table_dataset, sample_ids)
populate_final_extract_table(fq_temp_table_dataset, fq_destination_table_data, fq_destination_table_samples)
finally:
dump_job_stats()

print(f"\nFinal cohort extract written to {fq_destination_dataset}.{destination_table}\n")
print(f"\nFinal cohort extract data written to {fq_destination_table_data}\n")

if __name__ == '__main__':
parser = argparse.ArgumentParser(allow_abbrev=False, description='Extract a cohort from BigQuery Variant Store ')

parser.add_argument('--fq_petvet_dataset',type=str, help='project.dataset location of pet/vet data', required=True)
parser.add_argument('--fq_temp_table_dataset',type=str, help='project.dataset location where results should be stored', required=True)
parser.add_argument('--fq_destination_dataset',type=str, help='project.dataset location where results should be stored', required=True)
parser.add_argument('--destination_table',type=str, help='destination table', required=True)
parser.add_argument('--fq_cohort_sample_names',type=str, help='FQN of cohort table to extract, contains "sample_name" column', required=True)
parser.add_argument('--destination_cohort_table_prefix',type=str, help='prefix used for destination cohort extract tables (e.g. my_fantastic_cohort)', required=True)
parser.add_argument('--query_project',type=str, help='Google project where query should be executed', required=True)
parser.add_argument('--query_labels',type=str, action='append', help='Labels to put on the BQ query that will show up in the billing. Ex: --query_labels key1=value1 --query_labels key2=value2', required=False)
parser.add_argument('--min_variant_samples',type=int, help='Minimum variant samples at a site required to be emitted', required=False, default=0)
parser.add_argument('--fq_sample_mapping_table',type=str, help='Mapping table from sample_id to sample_name', required=True)
parser.add_argument('--sa_key_path',type=str, help='Path to json key file for SA', required=False)


parser.add_argument('--max_tables',type=int, help='Maximum number of PET/VET tables to consider', required=False, default=250)

parser.add_argument('--ttl',type=int, help='Temp table TTL in hours', required=False, default=72)
sample_args = parser.add_mutually_exclusive_group(required=True)
sample_args.add_argument('--sample_names_to_extract',type=str, help='File containing list of samples to extract, 1 per line')
sample_args.add_argument('--fq_cohort_sample_names',type=str, help='FQN of cohort table to extract, contains "sample_name" column')


# Execute the parse_args() method
args = parser.parse_args()

make_extract_table(args.fq_petvet_dataset,
args.max_tables,
args.sample_names_to_extract,
args.fq_cohort_sample_names,
args.query_project,
args.query_labels,
args.fq_temp_table_dataset,
args.fq_destination_dataset,
args.destination_table,
args.destination_cohort_table_prefix,
args.min_variant_samples,
args.fq_sample_mapping_table,
args.sa_key_path,
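With the new mutually exclusive group, the script can be driven either by a local file of sample names or by an existing BigQuery cohort table. A minimal sketch (hypothetical values) of the two modes:

import argparse

parser = argparse.ArgumentParser(allow_abbrev=False)
sample_args = parser.add_mutually_exclusive_group(required=True)
sample_args.add_argument('--sample_names_to_extract', type=str)
sample_args.add_argument('--fq_cohort_sample_names', type=str)

# Mode 1: a local file with one sample name per line.
args = parser.parse_args(['--sample_names_to_extract', 'sample_names.txt'])

# Mode 2: a fully qualified table that already has a sample_name column.
args = parser.parse_args(['--fq_cohort_sample_names', 'my-project.my_dataset.my_cohort'])

# Supplying both flags (or neither) makes argparse exit with a usage error.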
