Skip to content

Commit

Permalink
feat: train cvd with v2 (#455)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff authored Nov 20, 2023
2 parents b204c4c + e673388 commit 6f66569
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 3 deletions.
2 changes: 2 additions & 0 deletions psycop/common/model_training_v2/config/baseline_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ def train_baseline_model(cfg: BaselineSchema) -> float:
cfg.logger.log_config(
cfg.dict(),
) # Dict handling, might have to be flattened depending on the logger. Probably want all loggers to take flattened dicts.
# TODO: Currently logs the resolved objects. We want to fix that.

result = cfg.trainer.train()
result.df.write_parquet(cfg.project_info.experiment_path / "eval_df.parquet")
# TODO: https://github.com/Aarhus-Psychiatry-Research/psycop-common/issues/447 Allow dynamic generation of experiments paths

return result.metric.value
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

# Example cfg for prefix_count_validator
# Example cfg for column_prefix_count_expectation
# You can find args at:
# psycop.common.model_training_v2.trainer.preprocessing.steps.column_validator
[placeholder]
test = ["prefix_", 2]
placeholder = ["pred", 1]
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def from_list(
return cls(prefix=prefix, count=count) # type: ignore


@BaselineRegistry.preprocessing.register("prefix_count_validator")
@BaselineRegistry.preprocessing.register("column_prefix_count_expectation")
class ColumnPrefixExpectation(PresplitStep):
def __init__(
self,
Expand Down
75 changes: 75 additions & 0 deletions psycop/projects/cvd/model_training/cvd_baseline.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
[project_info]
experiment_path = /.

[logger]
@loggers = "terminal_logger"

[trainer]
@trainers = "crossval_trainer"
outcome_col_name = "outc_score2_cvd_within_1825_days_maximum_fallback_0_dichotomous"
n_splits = 5
group_col_name = "dw_ek_borger"

[trainer.metric]
@metrics = "binary_auroc"

[trainer.training_data]
@data = "parquet_vertical_concatenator"
paths = ["E:/shared_resources/cvd/e2e_base_test/flattened_datasets/train.parquet", "E:/shared_resources/cvd/e2e_base_test/flattened_datasets/val.parquet"]

[trainer.logger]
@loggers = "terminal_logger"

#################
# Preprocessing #
#################
[trainer.preprocessing_pipeline]
@preprocessing = "baseline_preprocessing_pipeline"

[trainer.preprocessing_pipeline.*.bool_to_int]
@preprocessing = "bool_to_int"

[trainer.preprocessing_pipeline.*.columns_exist]
@preprocessing = "column_exists_validator"

[trainer.preprocessing_pipeline.*.columns_exist.*]
age = "pred_age_in_years"
pred_time_uuid = "prediction_time_uuid"

[trainer.preprocessing_pipeline.*.regex_column_blacklist]
@preprocessing = "regex_column_blacklist"

[trainer.preprocessing_pipeline.*.regex_column_blacklist.*]
outcome = "outc_.+(365|1095).*"

[trainer.preprocessing_pipeline.*.temporal_col_filter]
@preprocessing = "temporal_col_filter"

[trainer.preprocessing_pipeline.*.column_prefix_count_expectation]
@preprocessing = "column_prefix_count_expectation"

[trainer.preprocessing_pipeline.*.column_prefix_count_expectation.*]
outcome_prefix = ["outc_", 1]

[trainer.preprocessing_pipeline.*.age_filter]
@preprocessing = "age_filter"
min_age = 0
max_age = 99
age_col_name = ${trainer.preprocessing_pipeline.*.columns_exist.*.age}

########
# Task #
########
[trainer.task]
@tasks = "binary_classification"
pred_time_uuid_col_name = ${trainer.preprocessing_pipeline.*.columns_exist.*.pred_time_uuid}

[trainer.task.task_pipe]
@task_pipelines = "binary_classification_pipeline"

[trainer.task.task_pipe.sklearn_pipe]
@task_pipelines = "pipe_constructor"

[trainer.task.task_pipe.sklearn_pipe.*.model]
@estimator_steps = "xgboost"

40 changes: 40 additions & 0 deletions psycop/projects/cvd/model_training/data_loader/trainval_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from collections.abc import Sequence
from pathlib import Path

import polars as pl
from functionalpy import Seq

from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry
from psycop.common.model_training_v2.trainer.base_dataloader import BaselineDataLoader


class MissingPathError(Exception):
...


@BaselineRegistry.data.register("parquet_vertical_concatenator")
class ParquetVerticalConcatenator(BaselineDataLoader):
def __init__(self, paths: Sequence[str]):
self.dataset_paths = [Path(arg) for arg in paths]

missing_paths = (
Seq(self.dataset_paths).map(self._check_path_exists).flatten().to_list()
)
if missing_paths:
raise MissingPathError(
f"""The following paths are missing:
{missing_paths}
""",
)

def _check_path_exists(self, path: Path) -> list[MissingPathError]:
if not path.exists():
return [MissingPathError(path)]

return []

def load(self) -> pl.LazyFrame:
return pl.concat(
how="vertical",
items=[pl.scan_parquet(path) for path in self.dataset_paths],
)
17 changes: 17 additions & 0 deletions psycop/projects/cvd/model_training/populate_cvd_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# ruff: noqa


def populate_with_cvd_registry() -> None:
from psycop.projects.cvd.model_training.data_loader.trainval_loader import (
ParquetVerticalConcatenator,
)
from psycop.projects.cvd.model_training.preprocessing.regex_filter import (
RegexColumnBlacklist,
)
from psycop.projects.cvd.model_training.preprocessing.datetime_filter import (
TemporalColumnFilter,
)
from psycop.projects.cvd.model_training.preprocessing.bool_to_int import BoolToInt


populate_with_cvd_registry()
21 changes: 21 additions & 0 deletions psycop/projects/cvd/model_training/preprocessing/bool_to_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import polars as pl
from polars import Boolean

from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry
from psycop.common.model_training_v2.trainer.preprocessing.step import (
PolarsFrame_T0,
PresplitStep,
)


@BaselineRegistry.preprocessing.register("bool_to_int")
class BoolToInt(PresplitStep):
def __init__(self):
pass

def apply(self, input_df: PolarsFrame_T0) -> PolarsFrame_T0:
for col_name in input_df.columns:
if input_df.schema[col_name] == Boolean: # type: ignore
input_df = input_df.with_columns(pl.col(col_name).cast(int))

return input_df
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import polars.selectors as cs

from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry
from psycop.common.model_training_v2.trainer.preprocessing.step import (
PolarsFrame_T0,
PresplitStep,
)


@BaselineRegistry.preprocessing.register("temporal_col_filter")
class TemporalColumnFilter(PresplitStep):
def __init__(self):
pass

def apply(self, input_df: PolarsFrame_T0) -> PolarsFrame_T0:
temporal_columns = input_df.select(cs.temporal()).columns
return input_df.drop(temporal_columns)
19 changes: 19 additions & 0 deletions psycop/projects/cvd/model_training/preprocessing/regex_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import polars as pl

from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry
from psycop.common.model_training_v2.trainer.preprocessing.step import (
PolarsFrame_T0,
PresplitStep,
)


@BaselineRegistry.preprocessing.register("regex_column_blacklist")
class RegexColumnBlacklist(PresplitStep):
def __init__(self, *args: str):
self.regex_blacklist = args

def apply(self, input_df: PolarsFrame_T0) -> PolarsFrame_T0:
for blacklist in self.regex_blacklist:
input_df = input_df.select(pl.exclude(f"^{blacklist}$"))

return input_df
16 changes: 16 additions & 0 deletions psycop/projects/cvd/model_training/train_model_e2e_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pathlib import Path

from psycop.common.model_training_v2.config.baseline_pipeline import (
train_baseline_model,
)
from psycop.common.model_training_v2.config.config_utils import (
load_baseline_config,
)
from psycop.projects.cvd.model_training.populate_cvd_registry import (
populate_with_cvd_registry,
)

if __name__ == "__main__":
populate_with_cvd_registry()
config = load_baseline_config(Path(__file__).parent / "cvd_baseline.cfg")
train_baseline_model(config)

0 comments on commit 6f66569

Please sign in to comment.