Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(susie_finemapper): fix in the fine-mapper in case of sum stat imputation is False #627

Merged
merged 2 commits into from
May 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 59 additions & 16 deletions src/gentropy/susie_finemapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,18 @@ def __init__(
ld_score_threshold=ld_score_threshold,
)
)
# Write result
result_logging["study_locus"].df.write.mode(session.write_mode).parquet(
output_path + "/" + study_locus_to_finemap
)
# Write log
result_logging["log"].to_parquet(
output_path_log + "/" + study_locus_to_finemap + ".parquet",
engine="pyarrow",
index=False
)

if result_logging is not None:
# Write result
result_logging["study_locus"].df.write.mode(session.write_mode).parquet(
output_path + "/" + study_locus_to_finemap
)
# Write log
result_logging["log"].to_parquet(
output_path_log + "/" + study_locus_to_finemap + ".parquet",
engine="pyarrow",
index=False,
)
else:
result = self.susie_finemapper_ss_gathered(
session=session,
Expand Down Expand Up @@ -1009,7 +1011,7 @@ def susie_finemapper_one_studylocus_row_v3_dev_ss_gathered(
purity_mean_r2_threshold: float = 0,
purity_min_r2_threshold: float = 0.25,
cs_lbf_thr: float = 2,
) -> dict[str, Any]:
) -> dict[str, Any] | None:
"""Susie fine-mapper function that uses study-locus row with collected locus, chromosome and position as inputs.

Args:
Expand All @@ -1032,7 +1034,7 @@ def susie_finemapper_one_studylocus_row_v3_dev_ss_gathered(
cs_lbf_thr (float): credible set logBF threshold for filtering credible sets, default is 2

Returns:
dict[str, Any]: dictionary with study locus, number of GWAS variants, number of LD variants, number of variants after merge, number of outliers, number of imputed variants, number of variants to fine-map
dict[str, Any] | None: dictionary with study locus, number of GWAS variants, number of LD variants, number of variants after merge, number of outliers, number of imputed variants, number of variants to fine-map, or None
"""
# PLEASE DO NOT REMOVE THIS LINE
pd.DataFrame.iteritems = pd.DataFrame.items
Expand Down Expand Up @@ -1077,6 +1079,11 @@ def susie_finemapper_one_studylocus_row_v3_dev_ss_gathered(
.filter(f.col("z").isNotNull())
)

# Remove ALL duplicated variants from GWAS DataFrame - we don't know which is correct
variant_counts = gwas_df.groupBy("variantId").count()
unique_variants = variant_counts.filter(f.col("count") == 1)
gwas_df = gwas_df.join(unique_variants, on="variantId", how="left_semi")

ld_index = (
GnomADLDMatrix()
.get_locus_index(
Expand All @@ -1097,14 +1104,50 @@ def susie_finemapper_one_studylocus_row_v3_dev_ss_gathered(
).cast("string"),
)
)

gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
ld_index, gnomad_ancestry=major_population
# Remove ALL duplicated variants from ld_index DataFrame - we don't know which is correct
variant_counts = ld_index.groupBy("variantId").count()
unique_variants = variant_counts.filter(f.col("count") == 1)
ld_index = ld_index.join(unique_variants, on="variantId", how="left_semi").sort(
"idx"
)

if not run_sumstat_imputation:
# Filtering out the variants that are not in the LD matrix, we don't need them
gwas_index = gwas_df.join(
ld_index.select("variantId", "alleles", "idx"), on="variantId"
).sort("idx")
gwas_df = gwas_index.select(
"variantId",
"z",
"chromosome",
"position",
"beta",
"StandardError",
)
gwas_index = gwas_index.drop(
"z", "chromosome", "position", "beta", "StandardError"
)
if gwas_index.rdd.isEmpty():
logging.warning("No overlapping variants in the LD Index")
return None
gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
gwas_index, gnomad_ancestry=major_population
)
else:
gwas_index = gwas_df.join(
ld_index.select("variantId", "alleles", "idx"), on="variantId"
).sort("idx")
if gwas_index.rdd.isEmpty():
logging.warning("No overlapping variants in the LD Index")
return None
gwas_index = ld_index
gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
gwas_index, gnomad_ancestry=major_population
)

out = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes(
GWAS_df=gwas_df,
ld_index=ld_index,
ld_index=gwas_index,
gnomad_ld=gnomad_ld,
L=max_causal_snps,
session=session,
Expand Down