feat(dataset): add Dataset.filter to create a new dataset instance with filtered data #470

Merged · 2 commits · Feb 6, 2024
8 changes: 8 additions & 0 deletions docs/howto/python_api/d_inspect_dataset.md
@@ -18,6 +18,14 @@ The `df` attribute of a Dataset instance is key to interacting with and inspecting

This method displays the first 10 rows of your dataset, giving you a snapshot of your data's structure and content.
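
As a quick sketch of the call behind that snippet (an assumption for illustration: `summary_stats` is a `SummaryStatistics` instance and the snippet uses `DataFrame.show`):

```python
# Print the first 10 rows of the dataset's underlying Spark DataFrame.
summary_stats.df.show(10)
```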

### Filter data

```python
--8<-- "src_snippets/howto/python_api/d_inspect_dataset.py:filter_dataset"
```

This method filters the data based on a condition, such as the value of a column. Applying a filter creates a new `Dataset` instance containing only the rows that match; the original dataset is left unchanged.
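
For instance, a minimal sketch of the new-instance behavior (assuming `summary_stats` is a `SummaryStatistics` dataset, as in the snippet above):

```python
import pyspark.sql.functions as f

# Keep only associations on chromosome 22; the original dataset is untouched.
filtered = summary_stats.filter(condition=f.col("chromosome") == "22")

assert filtered is not summary_stats  # a brand-new Dataset instance
assert isinstance(filtered, type(summary_stats))  # same concrete subclass
```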

### Understand the schema

```python
13 changes: 13 additions & 0 deletions docs/src_snippets/howto/python_api/d_inspect_dataset.py
@@ -9,6 +9,17 @@
from pyspark.sql.types import StructType


def filter_dataset(summary_stats: SummaryStatistics) -> SummaryStatistics:
"""Docs to filter the `df` attribute of a dataset using Dataset.filter."""
# --8<-- [start:filter_dataset]
import pyspark.sql.functions as f

# Filter summary statistics to only include associations in chromosome 22
filtered = summary_stats.filter(condition=f.col("chromosome") == "22")
# --8<-- [end:filter_dataset]
return filtered


def interact_w_dataframe(summary_stats: SummaryStatistics) -> SummaryStatistics:
"""Docs to interact with the `df` attribute of a dataset."""
# --8<-- [start:print_dataframe]
Expand All @@ -22,6 +33,7 @@ def interact_w_dataframe(summary_stats: SummaryStatistics) -> SummaryStatistics:
# --8<-- [end:print_dataframe_schema]
return summary_stats


def get_dataset_schema(summary_stats: SummaryStatistics) -> StructType:
"""Docs to get the schema of a dataset."""
# --8<-- [start:get_dataset_schema]
Expand All @@ -30,6 +42,7 @@ def get_dataset_schema(summary_stats: SummaryStatistics) -> StructType:
# --8<-- [end:get_dataset_schema]
return schema


def write_data(summary_stats: SummaryStatistics) -> None:
"""Docs to write a dataset to disk."""
# --8<-- [start:write_parquet]
15 changes: 14 additions & 1 deletion src/gentropy/dataset/dataset.py
@@ -10,7 +10,7 @@
from gentropy.common.schemas import flatten_schema

if TYPE_CHECKING:
-    from pyspark.sql import DataFrame
+    from pyspark.sql import Column, DataFrame
    from pyspark.sql.types import StructType

    from gentropy.common.session import Session

@@ -94,6 +94,19 @@ def from_parquet(
            raise ValueError(f"Parquet file is empty: {path}")
        return cls(_df=df, _schema=schema)

    def filter(self: Self, condition: Column) -> Self:
        """Creates a new instance of a Dataset with the DataFrame filtered by the condition.

        Args:
            condition (Column): Condition to filter the DataFrame

        Returns:
            Self: Filtered Dataset
        """
        df = self._df.filter(condition)
        class_constructor = self.__class__
        return class_constructor(_df=df, _schema=class_constructor.get_schema())

    def validate_schema(self: Dataset) -> None:
        """Validate DataFrame schema against expected class schema.
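
As a sketch of how the new method behaves at runtime (illustrative, not part of the diff: `study_index` stands for any `StudyIndex` instance, with the `studyType` column taken from the tests below):

```python
import pyspark.sql.functions as f

# filter() delegates to the wrapped DataFrame and then rebuilds the dataset
# via self.__class__, so the result keeps the concrete subclass and schema.
gwas_studies = study_index.filter(condition=f.col("studyType") == "gwas")

assert isinstance(gwas_studies, type(study_index))  # still a StudyIndex
assert gwas_studies is not study_index  # a new instance; original unchanged
```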
14 changes: 14 additions & 0 deletions tests/dataset/test_dataset.py
@@ -1,8 +1,10 @@
"""Test Dataset class."""
from __future__ import annotations

import pyspark.sql.functions as f
import pytest
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.study_index import StudyIndex
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StructField, StructType

@@ -42,3 +44,15 @@ def _setup(self: TestCoalesceAndRepartition, spark: SparkSession) -> None:
            ),
            _schema=MockDataset.get_schema(),
        )


def test_dataset_filter(mock_study_index: StudyIndex) -> None:
    """Test Dataset.filter."""
    expected_filter_value = "gwas"
    condition = f.col("studyType") == expected_filter_value

    filtered = mock_study_index.filter(condition)
    assert (
        filtered.df.select("studyType").distinct().toPandas()["studyType"].to_list()[0]
        == expected_filter_value
    ), "Filtering failed."
6 changes: 6 additions & 0 deletions tests/docs/test_inspect_dataset.py
@@ -3,11 +3,17 @@
from pyspark.sql.types import StructType

from docs.src_snippets.howto.python_api.d_inspect_dataset import (
    filter_dataset,
    get_dataset_schema,
    interact_w_dataframe,
)


def test_filter_dataset(mock_summary_statistics: SummaryStatistics) -> None:
"""Test filter_dataset returns a SummaryStatistics."""
assert isinstance(filter_dataset(mock_summary_statistics), SummaryStatistics)


def test_interact_w_dataframe(mock_summary_statistics: SummaryStatistics) -> None:
"""Test interact_w_dataframe returns a SummaryStatistics."""
assert isinstance(interact_w_dataframe(mock_summary_statistics), SummaryStatistics)