Skip to content

Commit

Permalink
Merge branch 'main' into il-clump-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
d0choa authored Dec 1, 2023
2 parents 893ceaa + 793a58b commit fc4b33e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/otg/datasource/gwas_catalog/study_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _resolve_study_id(study_id: Column, sub_study_description: Column) -> Column
"""
split_w = Window.partitionBy(study_id).orderBy(sub_study_description)
row_number = f.dense_rank().over(split_w)
substudy_count = f.count(row_number).over(split_w)
substudy_count = f.approx_count_distinct(row_number).over(split_w)
return f.when(substudy_count == 1, study_id).otherwise(
f.concat_ws("_", study_id, row_number)
)
Expand Down
70 changes: 70 additions & 0 deletions tests/datasource/gwas_catalog/test_gwas_catalog_study_splitter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
"""Tests GWAS Catalog study splitter."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import pyspark.sql.functions as f
import pytest

from otg.datasource.gwas_catalog.associations import GWASCatalogAssociations
from otg.datasource.gwas_catalog.study_index import GWASCatalogStudyIndex
from otg.datasource.gwas_catalog.study_splitter import GWASCatalogStudySplitter

if TYPE_CHECKING:
from pyspark.sql import SparkSession


def test_gwas_catalog_splitter_split(
mock_study_index_gwas_catalog: GWASCatalogStudyIndex,
Expand All @@ -17,3 +25,65 @@ def test_gwas_catalog_splitter_split(

assert isinstance(d1, GWASCatalogStudyIndex)
assert isinstance(d2, GWASCatalogAssociations)


@pytest.mark.parametrize(
    "observed, expected",
    [
        # Test 1 - it shouldn't split
        (
            # observed - 2 associations with the same subStudy annotation
            [
                (
                    "varA",
                    "GCST003436",
                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
                ),
                (
                    "varB",
                    "GCST003436",
                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
                ),
            ],
            # expected - 2 associations with the same unsplit updatedStudyId
            [
                ("GCST003436",),
                ("GCST003436",),
            ],
        ),
        # Test 2 - it should split
        (
            # observed - 2 associations with different subStudy annotations
            [
                (
                    "varA",
                    "GCST003436",
                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
                ),
                (
                    "varB",
                    "GCST003436",
                    "Uterine carcinoma|no_pvalue_text|EFO_0002919",
                ),
            ],
            # expected - 2 associations with distinct, split updatedStudyIds
            [
                ("GCST003436",),
                ("GCST003436_2",),
            ],
        ),
    ],
)
def test__resolve_study_id(
    spark: SparkSession, observed: list[Any], expected: list[Any]
) -> None:
    """Test _resolve_study_id.

    Args:
        spark (SparkSession): Spark session fixture.
        observed (list[Any]): Input rows of (variantId, studyId, subStudyDescription).
        expected (list[Any]): Expected single-column rows holding the resolved study id.
    """
    observed_df = spark.createDataFrame(
        observed, schema=["variantId", "studyId", "subStudyDescription"]
    ).select(
        # Alias the *result* of the resolution so the output column is named
        # updatedStudyId and matches the expected DataFrame's schema (the
        # original aliased the input column, leaving the output auto-named).
        GWASCatalogStudySplitter._resolve_study_id(
            f.col("studyId"), f.col("subStudyDescription")
        ).alias("updatedStudyId")
    )
    expected_df = spark.createDataFrame(expected, schema=["updatedStudyId"])
    assert observed_df.collect() == expected_df.collect()

0 comments on commit fc4b33e

Please sign in to comment.