diff --git a/gnomad_constraint/pipeline/constraint_pipeline.py b/gnomad_constraint/pipeline/constraint_pipeline.py index 97776922..edfaa6c2 100644 --- a/gnomad_constraint/pipeline/constraint_pipeline.py +++ b/gnomad_constraint/pipeline/constraint_pipeline.py @@ -244,6 +244,7 @@ def main(args): version = args.version test = args.test overwrite = args.overwrite + skip_downsamplings = args.skip_downsamplings max_af = args.max_af pops = args.pops @@ -262,6 +263,10 @@ def main(args): ) logger.info("The following downsamplings will be used: %s", downsamplings) + # If pops not specified, set to empty Tuple + if not pops: + pops = () + # Drop chromosome Y from version v4.0 (can add back in when obtain chrY # methylation data). if int(version[0]) >= 4: @@ -422,10 +427,15 @@ def main(args): "mane_select" if version_4_and_above else "canonical" ), # Switch to using MANE Select transcripts rather than canonical for gnomAD v4 and later versions. global_annotation="training_dataset_params", + skip_downsamplings=skip_downsamplings, ) if use_v2_release_mutation_ht: op_ht = op_ht.annotate_globals(use_v2_release_mutation_ht=True) - op_ht.write(getattr(res, f"train_{r}_ht").path, overwrite=overwrite) + # op_ht.write(getattr(res, f"train_{r}_ht").path, overwrite=overwrite) + op_ht.write( + "gs://gnomad-kristen/constraint/gen_anc/train.ht", + overwrite=overwrite, + ) logger.info("Done with creating training dataset.") if args.build_models: @@ -436,7 +446,10 @@ def main(args): # chromosome X, and chromosome Y. for r in regions: # TODO: Remove repartition once partition_hint bugs are resolved. - training_ht = getattr(res, f"train_{r}_ht").ht() + # training_ht = getattr(res, f"train_{r}_ht").ht() + training_ht = hl.read_table( + "gs://gnomad-kristen/constraint/gen_anc/train.ht" + ) training_ht = training_ht.repartition(args.training_set_partition_hint) logger.info("Building %s plateau and coverage models...", r) @@ -450,15 +463,16 @@ def main(args): ) hl.experimental.write_expression( plateau_models, - getattr(res, f"model_{r}_plateau").path, + "gs://gnomad-kristen/constraint/gen_anc/plateau_models.he", + # getattr(res, f"model_{r}_plateau").path, overwrite=overwrite, ) - if not args.skip_coverage_model: - hl.experimental.write_expression( - coverage_model, - getattr(res, f"model_{r}_coverage").path, - overwrite=overwrite, - ) + # if not args.skip_coverage_model: + # hl.experimental.write_expression( + # coverage_model, + # getattr(res, f"model_{r}_coverage").path, + # overwrite=overwrite, + # ) logger.info("Done building %s models.", r) if args.apply_models: @@ -486,7 +500,10 @@ def main(args): exome_ht=getattr(res, f"preprocessed_{r}_exomes_ht").ht(), context_ht=getattr(res, f"preprocessed_{r}_context_ht").ht(), mutation_ht=mutation_ht, - plateau_models=getattr(res, f"model_{r}_plateau").he(), + plateau_models=hl.experimental.read_expression( + "gs://gnomad-kristen/constraint/gen_anc/plateau_models.he" + ), + # plateau_models=getattr(res, f"model_{r}_plateau").he(), coverage_model=( getattr(res, "model_autosome_par_coverage").he() if not args.skip_coverage_model @@ -495,6 +512,7 @@ def main(args): max_af=max_af, pops=pops, downsamplings=downsamplings, + skip_downsamplings=skip_downsamplings, obs_pos_count_partition_hint=args.apply_obs_pos_count_partition_hint, expected_variant_partition_hint=args.apply_expected_variant_partition_hint, custom_vep_annotation=custom_vep_annotation, @@ -509,7 +527,8 @@ def main(args): ) if use_v2_release_mutation_ht: oe_ht = oe_ht.annotate_globals(use_v2_release_mutation_ht=True) - oe_ht.write(getattr(res, f"apply_{r}_ht").path, overwrite=overwrite) + # oe_ht.write(getattr(res, f"apply_{r}_ht").path, overwrite=overwrite) + oe_ht.write("gs://gnomad-kristen/constraint/gen_anc/apply.ht") logger.info( "Done computing expected variant count and observed:expected ratio." @@ -992,6 +1011,11 @@ def main(args): help="Export constraint metrics to tsv file.", action="store_true", ) + parser.add_argument( + "--skip-downsamplings", + help="Whether to skip downsamplings when 'pops' is specified.", + action="store_true", + ) compute_constraint_args._group_actions.append(populations) diff --git a/gnomad_constraint/utils/constraint.py b/gnomad_constraint/utils/constraint.py index 2f1701ab..819bec0c 100644 --- a/gnomad_constraint/utils/constraint.py +++ b/gnomad_constraint/utils/constraint.py @@ -1,4 +1,5 @@ """Script containing utility functions used in the constraint pipeline.""" + import logging from typing import Dict, List, Optional, Tuple @@ -16,7 +17,7 @@ compute_pli, count_variants_by_group, get_constraint_flags, - get_downsampling_freq_indices, + get_pop_freq_indices, oe_aggregation_expr, oe_confidence_interval, trimer_from_heptamer, @@ -182,6 +183,7 @@ def create_observed_and_possible_ht( low_coverage_filter: int = None, transcript_for_synonymous_filter: str = None, global_annotation: Optional[str] = None, + skip_downsamplings: bool = False, ) -> hl.Table: """ Count the observed variants and possible variants by substitution, context, methylation level, and additional `grouping`. @@ -238,6 +240,7 @@ def create_observed_and_possible_ht( :param global_annotation: The annotation name to use as a global StructExpression annotation containing input parameter values. If no value is supplied, this global annotation will not be added. Default is None. + :param skip_downsamplings: Whether or not to skip pulling the downsampling data. :return: Table with observed variant and possible variant count. """ if low_coverage_filter is not None: @@ -292,6 +295,7 @@ def create_observed_and_possible_ht( count_downsamplings=pops, use_table_group_by=True, max_af=max_af, + skip_downsamplings=skip_downsamplings, ) # TODO: Remove repartition once partition_hint bugs are resolved. @@ -353,6 +357,7 @@ def apply_models( high_cov_definition: int = COVERAGE_CUTOFF, low_coverage_filter: int = None, use_mane_select: bool = True, + skip_downsamplings: bool = False, ) -> hl.Table: """ Compute the expected number of variants and observed:expected ratio using plateau models and coverage model. @@ -426,6 +431,7 @@ def apply_models( :param use_mane_select: Use MANE Select transcripts in grouping. Only used when `custom_vep_annotation` is set to 'transcript_consequences'. Default is True. + :param skip_downsamplings: Whether or not to skip pulling the downsampling data. :return: Table with `expected_variants` (expected variant counts) and `obs_exp` (observed:expected ratio) annotations. @@ -477,6 +483,7 @@ def apply_models( partition_hint=obs_pos_count_partition_hint, filter_coverage_over_0=True, transcript_for_synonymous_filter=None, + skip_downsamplings=skip_downsamplings, ) # NOTE: In v2 ht.mu_snp was incorrectly multiplied here by possible_variants, but this multiplication has now been moved, @@ -524,15 +531,15 @@ def apply_models( # Store which downsamplings are obtained for each pop in a # downsampling_meta dictionary. - ds = hl.eval(get_downsampling_freq_indices(ht.freq_meta, pop=pop)) + ds = hl.eval(get_pop_freq_indices(ht.freq_meta, pop=pop)) key_names = {key for _, meta_dict in ds for key in meta_dict.keys()} genetic_ancestry_label = "gen_anc" if "gen_anc" in key_names else "pop" downsampling_meta[pop] = [ - x[1]["downsampling"] + x[1].get("downsampling", "all") for x in ds - if (x[1][genetic_ancestry_label] == pop) - & ( - int(x[1]["downsampling"]) in downsamplings + if x[1][genetic_ancestry_label] == pop + and ( + int(x[1].get("downsampling", 0)) in downsamplings if downsamplings is not None else True ) @@ -897,9 +904,8 @@ def compute_constraint_metrics( # `annotation_dict` stats the rule of filtration for each annotation. annotation_dict = { # Filter to classic LoF annotations with LOFTEE HC or LC. - "lof_hc_lc": hl.literal(set(classic_lof_annotations)).contains( - ht.annotation - ) & ((ht.modifier == "HC") | (ht.modifier == "LC")), + "lof_hc_lc": hl.literal(set(classic_lof_annotations)).contains(ht.annotation) + & ((ht.modifier == "HC") | (ht.modifier == "LC")), # Filter to LoF annotations with LOFTEE HC. "lof": ht.modifier == "HC", # Filter to missense variants. @@ -1009,6 +1015,29 @@ def compute_constraint_metrics( z_raw=raw_z_expr, ) + gen_anc_lower_struct = {} + gen_anc_upper_struct = {} + gen_anc_z_raw_struct = {} + + # Calculate lower and upper cis, and raw z scores for each pop, excluding downsamplings. + for pop in pops: + obs_expr = ht[ann]["gen_anc_obs"][pop][0] + exp_expr = ht[ann]["gen_anc_exp"][pop][0] + oe_ci_expr = oe_confidence_interval(obs_expr, exp_expr) + raw_z_expr = calculate_raw_z_score(obs_expr, exp_expr) + + lower_struct[pop] = oe_ci_expr.lower + upper_struct[pop] = oe_ci_expr.upper + gen_anc_z_raw_struct[pop] = raw_z_expr + + # Annotate the table with the structs. + ann_expr[ann] = ann_expr[ann].annotate( + gen_anc_oe_ci=hl.struct( + lower=hl.struct(**lower_struct), upper=hl.struct(**upper_struct) + ), + gen_anc_z_raw=hl.struct(**gen_anc_z_raw_struct), + ) + ann_expr["constraint_flags"] = add_filters_expr(filters=constraint_flags_expr) ht = ht.annotate(**ann_expr) ht = ht.checkpoint(