Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cohortdefiners return validatedframes #711

Merged
merged 7 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions psycop/common/cohort_definition.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,60 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Protocol, runtime_checkable

import polars as pl
from wasabi import Printer

from psycop.common.global_utils.pydantic_basemodel import PSYCOPBaseModel
from psycop.common.types.validated_frame import ValidatedFrame
from psycop.common.types.validator_rules import (
ColumnExistsRule,
ColumnTypeRule,
ValidatorRule,
)


@dataclass(frozen=True)
class PredictionTimeFrame(ValidatedFrame[pl.DataFrame]):
"""ValidatedFrame wrapping a polars DataFrame of prediction times.

Declares validator rules for two columns: an entity-id column
(default name "dw_ek_borger", expected type pl.Int64) and a timestamp
column (default name "timestamp", expected type pl.Datetime).

How and when these rules are applied is defined by ValidatedFrame,
which is not shown here.
"""

# The wrapped polars DataFrame.
frame: pl.DataFrame

# Entity-id column name and the rules declared for it.
entity_id_col_name: str = "dw_ek_borger"
entity_id_col_rules: Sequence[ValidatorRule] = (
ColumnExistsRule(),
sarakolding marked this conversation as resolved.
Show resolved Hide resolved
sarakolding marked this conversation as resolved.
Show resolved Hide resolved
ColumnTypeRule(expected_type=pl.Int64),
)

# Timestamp column name and the rules declared for it.
timestamp_col_name: str = "timestamp"
timestamp_col_rules: Sequence[ValidatorRule] = (
ColumnExistsRule(),
sarakolding marked this conversation as resolved.
Show resolved Hide resolved
ColumnTypeRule(expected_type=pl.Datetime),
)

# NOTE(review): presumably permits columns beyond the two named above;
# confirm against ValidatedFrame.
allow_extra_columns: bool = True


@dataclass(frozen=True)
class OutcomeTimestampFrame(ValidatedFrame[pl.DataFrame]):
"""ValidatedFrame wrapping a polars DataFrame of outcome timestamps.

Declares validator rules for two columns: an entity-id column
(default name "dw_ek_borger", expected type pl.Int64) and a timestamp
column (default name "timestamp", expected type pl.Datetime).

How and when these rules are applied is defined by ValidatedFrame,
which is not shown here.
"""

# The wrapped polars DataFrame.
frame: pl.DataFrame

# Entity-id column name and the rules declared for it.
entity_id_col_name: str = "dw_ek_borger"
entity_id_col_rules: Sequence[ValidatorRule] = (
ColumnExistsRule(),
sarakolding marked this conversation as resolved.
Show resolved Hide resolved
ColumnTypeRule(expected_type=pl.Int64),
)

# Timestamp column name and the rules declared for it.
timestamp_col_name: str = "timestamp"
timestamp_col_rules: Sequence[ValidatorRule] = (
ColumnExistsRule(),
sarakolding marked this conversation as resolved.
Show resolved Hide resolved
ColumnTypeRule(expected_type=pl.Datetime),
)

# NOTE(review): presumably permits columns beyond the two named above;
# confirm against ValidatedFrame.
allow_extra_columns: bool = True


@runtime_checkable
Expand Down Expand Up @@ -33,8 +82,9 @@ def n_dropped_ids(self) -> int:
return self.n_ids_before - self.n_ids_after


class FilteredPredictionTimeBundle(PSYCOPBaseModel):
prediction_times: pl.DataFrame
@dataclass(frozen=True)
class FilteredPredictionTimeBundle:
prediction_times: PredictionTimeFrame
filter_steps: list[StepDelta]


Expand All @@ -46,7 +96,7 @@ def get_filtered_prediction_times_bundle() -> FilteredPredictionTimeBundle:

@staticmethod
@abstractmethod
def get_outcome_timestamps() -> pl.DataFrame:
def get_outcome_timestamps() -> OutcomeTimestampFrame:
...


Expand Down Expand Up @@ -96,6 +146,6 @@ def filter_prediction_times(
prediction_times = prediction_times.drop("date_of_birth")

return FilteredPredictionTimeBundle(
prediction_times=prediction_times.collect(),
prediction_times=PredictionTimeFrame(frame=prediction_times.collect()),
filter_steps=stepdeltas,
)
2 changes: 1 addition & 1 deletion psycop/common/feature_generation/loaders/raw/load_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SplitFrame(ValidatedFrame[pl.LazyFrame]):
id_col_name: str = "dw_ek_borger"
id_col_rules: Sequence[ValidatorRule] = (
ColumnExistsRule(),
ColumnTypeRule(expected_type=pl.Utf8),
ColumnTypeRule(expected_type=pl.Int64),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ def create_prediction_times(
lookahead: dt.timedelta,
) -> tuple[PredictionTime, ...]:
outcome_timestamps = self._polars_dataframe_to_patient_timestamp_mapping(
dataframe=self.cohort_definer.get_outcome_timestamps(),
dataframe=self.cohort_definer.get_outcome_timestamps().frame,
id_col_name="dw_ek_borger",
patient_timestamp_col_name="timestamp",
)

naive_prediction_times = (
self.cohort_definer.get_filtered_prediction_times_bundle().prediction_times
self.cohort_definer.get_filtered_prediction_times_bundle().prediction_times.frame
).lazy()
prediction_times_for_split = self.split_filter.apply(
naive_prediction_times,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import datetime as dt

import polars as pl

from psycop.common.cohort_definition import CohortDefiner, FilteredPredictionTimeBundle
from psycop.common.cohort_definition import (
CohortDefiner,
FilteredPredictionTimeBundle,
OutcomeTimestampFrame,
PredictionTimeFrame,
)
from psycop.common.data_structures.test_patient import get_test_patient
from psycop.common.feature_generation.sequences.prediction_times_from_cohort import (
PredictionTimesFromCohort,
Expand All @@ -23,18 +26,18 @@ def get_filtered_prediction_times_bundle() -> FilteredPredictionTimeBundle:
""",
)
return FilteredPredictionTimeBundle(
prediction_times=df,
prediction_times=PredictionTimeFrame(df),
filter_steps=[],
)

@staticmethod
def get_outcome_timestamps() -> pl.DataFrame:
def get_outcome_timestamps() -> OutcomeTimestampFrame:
df = str_to_pl_df(
"""dw_ek_borger,timestamp
1,2021-01-02
""",
)
return df
return OutcomeTimestampFrame(frame=df)


def test_polars_dataframe_to_dict():
Expand Down
4 changes: 2 additions & 2 deletions psycop/common/test_cohort_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def test_filter_prediction_times():
prediction_times = str_to_pl_df(
"""
entity_id, timestamp,
dw_ek_borger, timestamp,
1, 2020-01-01,
1, 2019-01-01, # Filtered because of timestamp in filter 1
1, 2018-01-01, # Filtered because of timestamp in filter 2
Expand All @@ -36,4 +36,4 @@ def apply(self, df: pl.LazyFrame) -> pl.LazyFrame:
entity_id_col_name="entity_id",
sarakolding marked this conversation as resolved.
Show resolved Hide resolved
)

assert len(filtered.prediction_times) == 1
assert len(filtered.prediction_times.frame) == 1
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from psycop.common.cohort_definition import (
CohortDefiner,
FilteredPredictionTimeBundle,
OutcomeTimestampFrame,
filter_prediction_times,
)
from psycop.common.feature_generation.loaders.raw.load_visits import (
Expand Down Expand Up @@ -41,8 +42,8 @@ def get_filtered_prediction_times_bundle() -> FilteredPredictionTimeBundle:
)

@staticmethod
def get_outcome_timestamps() -> pl.DataFrame:
return pl.from_pandas(get_first_cancer_diagnosis())
def get_outcome_timestamps() -> OutcomeTimestampFrame:
return OutcomeTimestampFrame(frame=pl.from_pandas(get_first_cancer_diagnosis()))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion psycop/projects/cancer/feature_generation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

init_wandb_and_generate_feature_set(
project_info=get_cancer_project_info(),
eligible_prediction_times=CancerCohortDefiner.get_filtered_prediction_times_bundle().prediction_times.to_pandas(),
eligible_prediction_times=CancerCohortDefiner.get_filtered_prediction_times_bundle().prediction_times.frame.to_pandas(),
feature_specs=get_cancer_feature_specifications(),
generate_in_chunks=True,
chunksize=10,
Expand Down
4 changes: 2 additions & 2 deletions psycop/projects/cancer/feature_generation/specify_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _get_outcome_specs(self) -> list[OutcomeSpec]:
return [
OutcomeSpec(
feature_base_name="first_cancer_diagnosis",
timeseries_df=CancerCohortDefiner.get_outcome_timestamps().to_pandas(),
timeseries_df=CancerCohortDefiner.get_outcome_timestamps().frame.to_pandas(),
lookahead_days=365,
aggregation_fn=maximum,
fallback=0,
Expand All @@ -134,7 +134,7 @@ def _get_outcome_specs(self) -> list[OutcomeSpec]:
return OutcomeGroupSpec(
named_dataframes=[
NamedDataframe(
df=CancerCohortDefiner.get_outcome_timestamps().to_pandas(),
df=CancerCohortDefiner.get_outcome_timestamps().frame.to_pandas(),
name="first_cancer_diagnosis",
),
],
Expand Down
2 changes: 1 addition & 1 deletion psycop/projects/cancer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
if __name__ == "__main__":
feature_set_path = init_wandb_and_generate_feature_set(
project_info=get_cancer_project_info(),
eligible_prediction_times=CancerCohortDefiner.get_filtered_prediction_times_bundle().prediction_times.to_pandas(),
eligible_prediction_times=CancerCohortDefiner.get_filtered_prediction_times_bundle().prediction_times.frame.to_pandas(),
feature_specs=get_cancer_feature_specifications(),
generate_in_chunks=True,
chunksize=10,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from psycop.common.cohort_definition import (
CohortDefiner,
FilteredPredictionTimeBundle,
OutcomeTimestampFrame,
filter_prediction_times,
)
from psycop.common.feature_generation.loaders.raw.load_visits import (
Expand Down Expand Up @@ -45,11 +46,13 @@ def get_filtered_prediction_times_bundle() -> FilteredPredictionTimeBundle:
return result

@staticmethod
def get_outcome_timestamps() -> pl.DataFrame:
return (
pl.from_pandas(get_first_clozapine_prescription())
.with_columns(value=pl.lit(1))
.select(["dw_ek_borger", "timestamp", "value"])
def get_outcome_timestamps() -> OutcomeTimestampFrame:
return OutcomeTimestampFrame(
frame=(
pl.from_pandas(get_first_clozapine_prescription())
.with_columns(value=pl.lit(1))
.select(["dw_ek_borger", "timestamp", "value"])
),
)


Expand Down
2 changes: 1 addition & 1 deletion psycop/projects/clozapine/feature_generation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def main(
if generate_in_chunks:
flattened_df = ChunkedFeatureGenerator.create_flattened_dataset_with_chunking(
project_info=project_info,
eligible_prediction_times=ClozapineCohortDefiner.get_filtered_prediction_times_bundle().prediction_times.to_pandas(),
eligible_prediction_times=ClozapineCohortDefiner.get_filtered_prediction_times_bundle().prediction_times.frame.to_pandas(),
feature_specs=feature_specs, # type: ignore
chunksize=chunksize,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _get_outcome_specs(self) -> list[OutcomeSpec]:
return OutcomeGroupSpec(
named_dataframes=[
NamedDataframe(
df=ClozapineCohortDefiner.get_outcome_timestamps().to_pandas(),
df=ClozapineCohortDefiner.get_outcome_timestamps().frame.to_pandas(),
name="first_clozapine_prescription",
),
],
Expand All @@ -176,7 +176,7 @@ def _get_outcome_timestamp_specs(self) -> list[OutcomeSpec]:
return OutcomeGroupSpec(
named_dataframes=[
NamedDataframe(
df=ClozapineCohortDefiner.get_outcome_timestamps().to_pandas(),
df=ClozapineCohortDefiner.get_outcome_timestamps().frame.to_pandas(),
name="first_clozapine_prescription",
),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from psycop.common.cohort_definition import (
CohortDefiner,
FilteredPredictionTimeBundle,
OutcomeTimestampFrame,
filter_prediction_times,
)
from psycop.common.feature_generation.loaders.raw.load_visits import (
Expand Down Expand Up @@ -45,11 +46,13 @@ def get_filtered_prediction_times_bundle() -> FilteredPredictionTimeBundle:
return result

@staticmethod
def get_outcome_timestamps() -> pl.DataFrame:
return (
pl.from_pandas(get_first_cvd_indicator())
.with_columns(value=pl.lit(1))
.select(["dw_ek_borger", "timestamp", "value"])
def get_outcome_timestamps() -> OutcomeTimestampFrame:
return OutcomeTimestampFrame(
frame=(
pl.from_pandas(get_first_cvd_indicator())
.with_columns(value=pl.lit(1))
.select(["dw_ek_borger", "timestamp", "value"])
),
)


Expand Down
2 changes: 1 addition & 1 deletion psycop/projects/cvd/feature_generation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_cvd_project_info() -> ProjectInfo:
if __name__ == "__main__":
project_info = get_cvd_project_info()
eligible_prediction_times = (
CVDCohortDefiner.get_filtered_prediction_times_bundle().prediction_times.to_pandas()
CVDCohortDefiner.get_filtered_prediction_times_bundle().prediction_times.frame.to_pandas()
)
feature_specs = CVDFeatureSpecifier().get_feature_specs(layer=3)

Expand Down
2 changes: 1 addition & 1 deletion psycop/projects/cvd/feature_generation/specify_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _get_outcome_specs(self) -> list[OutcomeSpec]:
return OutcomeGroupSpec(
named_dataframes=[
NamedDataframe(
df=CVDCohortDefiner.get_outcome_timestamps().to_pandas(),
df=CVDCohortDefiner.get_outcome_timestamps().frame.to_pandas(),
name="score2_cvd",
),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from psycop.common.cohort_definition import (
CohortDefiner,
FilteredPredictionTimeBundle,
OutcomeTimestampFrame,
filter_prediction_times,
)
from psycop.projects.forced_admission_inpatient.cohort.extract_admissions_and_visits.get_forced_admissions import (
Expand Down Expand Up @@ -50,8 +51,10 @@ def get_filtered_prediction_times_bundle(
)

@staticmethod
def get_outcome_timestamps() -> pl.DataFrame:
return pl.from_pandas(forced_admissions_onset_timestamps())
def get_outcome_timestamps() -> OutcomeTimestampFrame:
return OutcomeTimestampFrame(
frame=pl.from_pandas(forced_admissions_onset_timestamps()),
)


if __name__ == "__main__":
Expand All @@ -65,8 +68,8 @@ def get_outcome_timestamps() -> pl.DataFrame:
)
)

df = bundle.prediction_times.to_pandas()
df = bundle.prediction_times.frame.to_pandas()

df_no_washout = bundle_no_washout.prediction_times.to_pandas()
df_no_washout = bundle_no_washout.prediction_times.frame.to_pandas()

outcome_timestamps = ForcedAdmissionsInpatientCohortDefiner.get_outcome_timestamps()
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main(
project_info=project_info,
eligible_prediction_times=ForcedAdmissionsInpatientCohortDefiner.get_filtered_prediction_times_bundle(
washout_on_prior_forced_admissions=washout_on_prior_forced_admissions,
).prediction_times.to_pandas(),
).prediction_times.frame.to_pandas(),
feature_specs=feature_specs, # type: ignore
chunksize=chunksize,
)
Expand All @@ -110,7 +110,7 @@ def main(
feature_specs=feature_specs, # type: ignore
prediction_times_df=ForcedAdmissionsInpatientCohortDefiner.get_filtered_prediction_times_bundle(
washout_on_prior_forced_admissions=washout_on_prior_forced_admissions,
).prediction_times.to_pandas(),
).prediction_times.frame.to_pandas(),
drop_pred_times_with_insufficient_look_distance=False,
project_info=project_info,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
if __name__ == "__main__":
pred_times = SczBpCohort.get_filtered_prediction_times_bundle().prediction_times

outcome_timestamps = SczBpCohort.get_outcome_timestamps().lazy()
outcome_timestamps = SczBpCohort.get_outcome_timestamps().frame.lazy()
outcome_with_age = SczBpAddAge().apply(outcome_timestamps)

first_eligible_outcome = (
pred_times.join(
pred_times.frame.join(
outcome_with_age.collect(),
how="left",
on="dw_ek_borger",
Expand Down
Loading