refactor: stop inheriting datasets in parsers (#313)
* refactor: stop inheriting datasets in parsers

* fix: typing issue

* refactor: include datasets in datasources

* test: fix incorrect import

* test: doctest function calls fixed

* test: doctest function calls fixed in studyindex

---------

Co-authored-by: Daniel Suveges <daniel.suveges@protonmail.com>
d0choa and DSuveges authored Dec 12, 2023
1 parent eea2a4c commit ee73572
Showing 24 changed files with 263 additions and 260 deletions.
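The change is the same across every touched datasource: parser classes such as FinnGenStudyIndex and EqtlCatalogueSummaryStats stop subclassing the generic datasets (StudyIndex, SummaryStatistics, StudyLocus) and instead return instances of them from their `from_source` methods. A minimal sketch of the pattern, using a simplified stand-in for the dataset class; the `Old`/`New` class names below are illustrative, not code from this commit:

```python
from __future__ import annotations

from dataclasses import dataclass

from pyspark.sql import DataFrame
from pyspark.sql.types import StructType


@dataclass
class StudyIndex:
    """Simplified stand-in for the otg dataset class (illustration only)."""

    _df: DataFrame
    _schema: StructType | None = None

    @classmethod
    def get_schema(cls) -> StructType | None:
        return None  # the real dataset class returns its Spark schema here


# Before: the parser *is* the dataset, so each source re-exposes the whole dataset API.
class FinnGenStudyIndexOld(StudyIndex):
    @classmethod
    def from_source(cls, finngen_studies: DataFrame) -> FinnGenStudyIndexOld:
        return cls(_df=finngen_studies, _schema=cls.get_schema())


# After: the parser only knows how to build the dataset; callers get a plain StudyIndex.
class FinnGenStudyIndexNew:
    @classmethod
    def from_source(cls, finngen_studies: DataFrame) -> StudyIndex:
        return StudyIndex(_df=finngen_studies, _schema=StudyIndex.get_schema())
```

Downstream code then only ever deals with the generic dataset types, which is why the return annotations throughout the diff widen from the parser classes to StudyIndex and SummaryStatistics.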
6 changes: 5 additions & 1 deletion docs/python_api/datasource/gwas_catalog/associations.md
@@ -2,4 +2,8 @@
title: Associations
---

-::: otg.datasource.gwas_catalog.associations.GWASCatalogAssociations
+::: otg.datasource.gwas_catalog.associations.GWASCatalogCuratedAssociationsParser
+
+---
+
+::: otg.datasource.gwas_catalog.associations.StudyLocusGWASCatalog
6 changes: 5 additions & 1 deletion docs/python_api/datasource/gwas_catalog/study_index.md
@@ -2,4 +2,8 @@
title: Study Index
---

-::: otg.datasource.gwas_catalog.study_index.GWASCatalogStudyIndex
+::: otg.datasource.gwas_catalog.study_index.StudyIndexGWASCatalogParser
+
+---
+
+::: otg.datasource.gwas_catalog.study_index.StudyIndexGWASCatalog
8 changes: 4 additions & 4 deletions src/otg/dataset/study_locus.py
@@ -171,7 +171,7 @@ def _align_overlapping_tags(
)

@staticmethod
-def _update_quality_flag(
+def update_quality_flag(
qc: Column, flag_condition: Column, flag_text: StudyLocusQualityCheck
) -> Column:
"""Update the provided quality control list with a new flag if condition is met.
@@ -410,7 +410,7 @@ def clump(self: StudyLocus) -> StudyLocus:
)
.withColumn(
"qualityControls",
-StudyLocus._update_quality_flag(
+StudyLocus.update_quality_flag(
f.col("qualityControls"),
f.col("is_lead_linked"),
StudyLocusQualityCheck.LD_CLUMPED,
@@ -430,7 +430,7 @@ def _qc_unresolved_ld(
"""
self.df = self.df.withColumn(
"qualityControls",
-self._update_quality_flag(
+self.update_quality_flag(
f.col("qualityControls"),
f.col("ldSet").isNull(),
StudyLocusQualityCheck.UNRESOLVED_LD,
@@ -450,7 +450,7 @@ def _qc_no_population(self: StudyLocus) -> StudyLocus:

self.df = self.df.withColumn(
"qualityControls",
-self._update_quality_flag(
+self.update_quality_flag(
f.col("qualityControls"),
f.col("ldPopulationStructure").isNull(),
StudyLocusQualityCheck.NO_POPULATION,
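Renaming `_update_quality_flag` to `update_quality_flag` turns the QC helper into part of the class's public surface, which matters once code outside StudyLocus (including the datasource parsers) builds and annotates these objects instead of subclassing them. A hedged usage sketch that mirrors the call sites above; `flag_unresolved_ld` and `study_locus_df` are illustrative, and both imports are assumed to resolve from `otg.dataset.study_locus` as the file suggests:

```python
import pyspark.sql.functions as f
from pyspark.sql import DataFrame

from otg.dataset.study_locus import StudyLocus, StudyLocusQualityCheck


def flag_unresolved_ld(study_locus_df: DataFrame) -> DataFrame:
    """Append a QC flag to rows whose LD set could not be resolved (illustrative)."""
    return study_locus_df.withColumn(
        "qualityControls",
        StudyLocus.update_quality_flag(
            f.col("qualityControls"),  # current list of QC flags
            f.col("ldSet").isNull(),  # condition that triggers the flag
            StudyLocusQualityCheck.UNRESOLVED_LD,  # flag to append when the condition holds
        ),
    )
```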
18 changes: 9 additions & 9 deletions src/otg/datasource/eqtl_catalogue/study_index.py
@@ -12,7 +12,7 @@
from pyspark.sql.column import Column


-class EqtlCatalogueStudyIndex(StudyIndex):
+class EqtlCatalogueStudyIndex:
"""Study index dataset from eQTL Catalogue."""

@staticmethod
@@ -92,29 +92,29 @@ def _all_attributes() -> List[Column]:
def from_source(
cls: type[EqtlCatalogueStudyIndex],
eqtl_studies: DataFrame,
-) -> EqtlCatalogueStudyIndex:
+) -> StudyIndex:
"""Ingest study level metadata from eQTL Catalogue.
Args:
eqtl_studies (DataFrame): ingested but unprocessed eQTL Catalogue studies.
Returns:
-EqtlCatalogueStudyIndex: preliminary processed study index for eQTL Catalogue studies.
+StudyIndex: preliminary processed study index for eQTL Catalogue studies.
"""
-return EqtlCatalogueStudyIndex(
+return StudyIndex(
_df=eqtl_studies.select(*cls._all_attributes()).withColumn(
"ldPopulationStructure",
-cls.aggregate_and_map_ancestries(f.col("discoverySamples")),
+StudyIndex.aggregate_and_map_ancestries(f.col("discoverySamples")),
),
-_schema=cls.get_schema(),
+_schema=StudyIndex.get_schema(),
)

@classmethod
def add_gene_id_column(
cls: type[EqtlCatalogueStudyIndex],
study_index_df: DataFrame,
summary_stats_df: DataFrame,
-) -> EqtlCatalogueStudyIndex:
+) -> StudyIndex:
"""Add a geneId column to the study index and explode.
While the original list contains one entry per tissue, what we consider as a single study is one mini-GWAS for
@@ -127,7 +127,7 @@ def add_gene_id_column(
summary_stats_df (DataFrame): summary statistics dataframe for eQTL Catalogue data.
Returns:
-EqtlCatalogueStudyIndex: final study index for eQTL Catalogue studies.
+StudyIndex: final study index for eQTL Catalogue studies.
"""
partial_to_full_study_id = (
summary_stats_df.select(f.col("studyId"))
@@ -148,4 +148,4 @@
.withColumn("geneId", f.regexp_extract(f.col("studyId"), r".*_([\_]+)", 1))
.drop("fullStudyId")
)
-return EqtlCatalogueStudyIndex(_df=study_index_df, _schema=cls.get_schema())
+return StudyIndex(_df=study_index_df, _schema=StudyIndex.get_schema())
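After the change, `EqtlCatalogueStudyIndex.from_source` hands back a generic `StudyIndex` validated against `StudyIndex.get_schema()`. A hedged sketch of the calling side; `raw_studies_df` is an assumed DataFrame of ingested eQTL Catalogue study metadata, and the `otg.dataset.study_index` import path is inferred from the repository layout rather than shown in this diff:

```python
from pyspark.sql import DataFrame

from otg.dataset.study_index import StudyIndex
from otg.datasource.eqtl_catalogue.study_index import EqtlCatalogueStudyIndex


def build_eqtl_study_index(raw_studies_df: DataFrame) -> StudyIndex:
    """Parse raw eQTL Catalogue study metadata into the generic study index dataset."""
    study_index = EqtlCatalogueStudyIndex.from_source(raw_studies_df)
    # The parser no longer subclasses StudyIndex; the result is the plain dataset type.
    assert isinstance(study_index, StudyIndex)
    return study_index
```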
10 changes: 5 additions & 5 deletions src/otg/datasource/eqtl_catalogue/summary_stats.py
@@ -17,7 +17,7 @@


@dataclass
-class EqtlCatalogueSummaryStats(SummaryStatistics):
+class EqtlCatalogueSummaryStats:
"""Summary statistics dataset for eQTL Catalogue."""

@staticmethod
@@ -49,14 +49,14 @@ def _full_study_id_regexp() -> Column:
def from_source(
cls: type[EqtlCatalogueSummaryStats],
summary_stats_df: DataFrame,
-) -> EqtlCatalogueSummaryStats:
+) -> SummaryStatistics:
"""Ingests all summary stats for all eQTL Catalogue studies.
Args:
summary_stats_df (DataFrame): an ingested but unprocessed summary statistics dataframe from eQTL Catalogue.
Returns:
-EqtlCatalogueSummaryStats: a processed summary statistics dataframe for eQTL Catalogue.
+SummaryStatistics: a processed summary statistics dataframe for eQTL Catalogue.
"""
processed_summary_stats_df = (
summary_stats_df.select(
@@ -87,7 +87,7 @@ def from_source(
)

# Initialise a summary statistics object.
-return cls(
+return SummaryStatistics(
_df=processed_summary_stats_df,
-_schema=cls.get_schema(),
+_schema=SummaryStatistics.get_schema(),
)
12 changes: 6 additions & 6 deletions src/otg/datasource/finngen/study_index.py
@@ -11,7 +11,7 @@
from pyspark.sql import DataFrame


-class FinnGenStudyIndex(StudyIndex):
+class FinnGenStudyIndex:
"""Study index dataset from FinnGen.
The following information is aggregated/extracted:
@@ -31,7 +31,7 @@ def from_source(
finngen_release_prefix: str,
finngen_summary_stats_url_prefix: str,
finngen_summary_stats_url_suffix: str,
-) -> FinnGenStudyIndex:
+) -> StudyIndex:
"""This function ingests study level metadata from FinnGen.
Args:
@@ -41,9 +41,9 @@
finngen_summary_stats_url_suffix (str): URL suffix for summary statistics location.
Returns:
-FinnGenStudyIndex: Parsed and annotated FinnGen study table.
+StudyIndex: Parsed and annotated FinnGen study table.
"""
-return FinnGenStudyIndex(
+return StudyIndex(
_df=finngen_studies.select(
f.concat(f.lit(f"{finngen_release_prefix}_"), f.col("phenocode")).alias(
"studyId"
@@ -73,7 +73,7 @@ def from_source(
).alias("summarystatsLocation"),
).withColumn(
"ldPopulationStructure",
-cls.aggregate_and_map_ancestries(f.col("discoverySamples")),
+StudyIndex.aggregate_and_map_ancestries(f.col("discoverySamples")),
),
-_schema=cls.get_schema(),
+_schema=StudyIndex.get_schema(),
)
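The same file shows the knock-on effect on shared helpers: `aggregate_and_map_ancestries` used to be reached through the inherited `cls`, and is now addressed on `StudyIndex` directly. A small sketch of that call shape; the column name comes from the diff, while the `Column` return type and the import path are assumptions:

```python
import pyspark.sql.functions as f
from pyspark.sql import Column

from otg.dataset.study_index import StudyIndex


def ld_population_structure(discovery_samples_col: str = "discoverySamples") -> Column:
    """Map discovery-sample ancestries to LD populations via the dataset-level helper."""
    # Before the refactor this was cls.aggregate_and_map_ancestries(...) inside the parser;
    # with inheritance gone, the helper is called on StudyIndex explicitly.
    return StudyIndex.aggregate_and_map_ancestries(f.col(discovery_samples_col))
```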
10 changes: 5 additions & 5 deletions src/otg/datasource/finngen/summary_stats.py
@@ -16,21 +16,21 @@


@dataclass
-class FinnGenSummaryStats(SummaryStatistics):
+class FinnGenSummaryStats:
"""Summary statistics dataset for FinnGen."""

@classmethod
def from_source(
cls: type[FinnGenSummaryStats],
summary_stats_df: DataFrame,
-) -> FinnGenSummaryStats:
+) -> SummaryStatistics:
"""Ingests all summary statst for all FinnGen studies.
Args:
summary_stats_df (DataFrame): Raw summary statistics dataframe
Returns:
-FinnGenSummaryStats: Processed summary statistics dataset
+SummaryStatistics: Processed summary statistics dataset
"""
processed_summary_stats_df = (
summary_stats_df
@@ -64,7 +64,7 @@ def from_source(
)

# Initializing summary statistics object:
-return cls(
+return SummaryStatistics(
_df=processed_summary_stats_df,
-_schema=cls.get_schema(),
+_schema=SummaryStatistics.get_schema(),
)
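Because every summary-statistics parser now returns the same dataset type, callers can treat sources interchangeably. A hedged sketch of that benefit; the datasource module paths follow the files above, while `otg.dataset.summary_statistics` and the dispatch function itself are assumptions for illustration:

```python
from pyspark.sql import DataFrame

from otg.dataset.summary_statistics import SummaryStatistics
from otg.datasource.eqtl_catalogue.summary_stats import EqtlCatalogueSummaryStats
from otg.datasource.finngen.summary_stats import FinnGenSummaryStats


def harmonise_summary_stats(raw_df: DataFrame, source: str) -> SummaryStatistics:
    """Dispatch to a source-specific parser; the return type no longer depends on the source."""
    if source == "finngen":
        return FinnGenSummaryStats.from_source(raw_df)
    if source == "eqtl_catalogue":
        return EqtlCatalogueSummaryStats.from_source(raw_df)
    raise ValueError(f"Unknown summary statistics source: {source!r}")
```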