-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
210 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 2 additions & 2 deletions
4
...refix_count_validator_20231117_092246.cfg → ...fix_count_expectation_20231117_135230.cfg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
|
||
# Example cfg for prefix_count_validator | ||
# Example cfg for column_prefix_count_expectation | ||
# You can find args at: | ||
# psycop.common.model_training_v2.trainer.preprocessing.steps.column_validator | ||
[placeholder] | ||
test = ["prefix_", 2] | ||
placeholder = ["pred", 1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
[project_info] | ||
experiment_path = /. | ||
|
||
[logger] | ||
@loggers = "terminal_logger" | ||
|
||
[trainer] | ||
@trainers = "crossval_trainer" | ||
outcome_col_name = "outc_score2_cvd_within_1825_days_maximum_fallback_0_dichotomous" | ||
n_splits = 5 | ||
group_col_name = "dw_ek_borger" | ||
|
||
[trainer.metric] | ||
@metrics = "binary_auroc" | ||
|
||
[trainer.training_data] | ||
@data = "parquet_vertical_concatenator" | ||
paths = ["E:/shared_resources/cvd/e2e_base_test/flattened_datasets/train.parquet", "E:/shared_resources/cvd/e2e_base_test/flattened_datasets/val.parquet"] | ||
|
||
[trainer.logger] | ||
@loggers = "terminal_logger" | ||
|
||
################# | ||
# Preprocessing # | ||
################# | ||
[trainer.preprocessing_pipeline] | ||
@preprocessing = "baseline_preprocessing_pipeline" | ||
|
||
[trainer.preprocessing_pipeline.*.bool_to_int] | ||
@preprocessing = "bool_to_int" | ||
|
||
[trainer.preprocessing_pipeline.*.columns_exist] | ||
@preprocessing = "column_exists_validator" | ||
|
||
[trainer.preprocessing_pipeline.*.columns_exist.*] | ||
age = "pred_age_in_years" | ||
pred_time_uuid = "prediction_time_uuid" | ||
|
||
[trainer.preprocessing_pipeline.*.regex_column_blacklist] | ||
@preprocessing = "regex_column_blacklist" | ||
|
||
[trainer.preprocessing_pipeline.*.regex_column_blacklist.*] | ||
outcome = "outc_.+(365|1095).*" | ||
|
||
[trainer.preprocessing_pipeline.*.temporal_col_filter] | ||
@preprocessing = "temporal_col_filter" | ||
|
||
[trainer.preprocessing_pipeline.*.column_prefix_count_expectation] | ||
@preprocessing = "column_prefix_count_expectation" | ||
|
||
[trainer.preprocessing_pipeline.*.column_prefix_count_expectation.*] | ||
outcome_prefix = ["outc_", 1] | ||
|
||
[trainer.preprocessing_pipeline.*.age_filter] | ||
@preprocessing = "age_filter" | ||
min_age = 0 | ||
max_age = 99 | ||
age_col_name = ${trainer.preprocessing_pipeline.*.columns_exist.*.age} | ||
|
||
######## | ||
# Task # | ||
######## | ||
[trainer.task] | ||
@tasks = "binary_classification" | ||
pred_time_uuid_col_name = ${trainer.preprocessing_pipeline.*.columns_exist.*.pred_time_uuid} | ||
|
||
[trainer.task.task_pipe] | ||
@task_pipelines = "binary_classification_pipeline" | ||
|
||
[trainer.task.task_pipe.sklearn_pipe] | ||
@task_pipelines = "pipe_constructor" | ||
|
||
[trainer.task.task_pipe.sklearn_pipe.*.model] | ||
@estimator_steps = "xgboost" | ||
|
40 changes: 40 additions & 0 deletions
40
psycop/projects/cvd/model_training/data_loader/trainval_loader.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from collections.abc import Sequence | ||
from pathlib import Path | ||
|
||
import polars as pl | ||
from functionalpy import Seq | ||
|
||
from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry | ||
from psycop.common.model_training_v2.trainer.base_dataloader import BaselineDataLoader | ||
|
||
|
||
class MissingPathError(Exception):
    """Signals that a dataset path handed to a data loader does not exist."""
|
||
|
||
@BaselineRegistry.data.register("parquet_vertical_concatenator")
class ParquetVerticalConcatenator(BaselineDataLoader):
    """Data loader that lazily scans several parquet files and stacks them vertically.

    Registered in the baseline registry as ``"parquet_vertical_concatenator"``.
    """

    def __init__(self, paths: Sequence[str]):
        """Store the dataset paths and fail fast if any of them is missing.

        Args:
            paths: Locations of the parquet files to concatenate.

        Raises:
            MissingPathError: If at least one of ``paths`` does not exist. The
                message lists every missing path, one per line.
        """
        self.dataset_paths = [Path(arg) for arg in paths]

        # Each check yields zero or one error, so flattening leaves exactly
        # the errors for the paths that are missing.
        missing_paths = (
            Seq(self.dataset_paths).map(self._check_path_exists).flatten().to_list()
        )
        if missing_paths:
            # str() of a single-argument exception is str() of that argument,
            # so this prints the plain paths rather than exception reprs.
            listing = "\n".join(f"  - {error}" for error in missing_paths)
            raise MissingPathError(f"The following paths are missing:\n{listing}")

    def _check_path_exists(self, path: Path) -> list[MissingPathError]:
        """Return a one-element list holding an error if ``path`` is missing, else ``[]``."""
        if not path.exists():
            return [MissingPathError(path)]

        return []

    def load(self) -> pl.LazyFrame:
        """Return a lazy frame that vertically concatenates all configured parquet files."""
        return pl.concat(
            how="vertical",
            items=[pl.scan_parquet(path) for path in self.dataset_paths],
        )
17 changes: 17 additions & 0 deletions
17
psycop/projects/cvd/model_training/populate_cvd_registry.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# ruff: noqa | ||
|
||
|
||
def populate_with_cvd_registry() -> None:
    """Import the CVD project's loader and preprocessing modules.

    The imported names are intentionally unused: the imports are done for
    their side effects (each imported class is decorated with a
    ``BaselineRegistry`` ``register`` call in its own module).
    """
    # Data loader
    from psycop.projects.cvd.model_training.data_loader.trainval_loader import (
        ParquetVerticalConcatenator,
    )

    # Preprocessing steps
    from psycop.projects.cvd.model_training.preprocessing.regex_filter import (
        RegexColumnBlacklist,
    )
    from psycop.projects.cvd.model_training.preprocessing.datetime_filter import (
        TemporalColumnFilter,
    )
    from psycop.projects.cvd.model_training.preprocessing.bool_to_int import BoolToInt


# Run at import time so that importing this module is enough to trigger the imports above.
populate_with_cvd_registry()
21 changes: 21 additions & 0 deletions
21
psycop/projects/cvd/model_training/preprocessing/bool_to_int.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import polars as pl | ||
from polars import Boolean | ||
|
||
from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry | ||
from psycop.common.model_training_v2.trainer.preprocessing.step import ( | ||
PolarsFrame_T0, | ||
PresplitStep, | ||
) | ||
|
||
|
||
@BaselineRegistry.preprocessing.register("bool_to_int")
class BoolToInt(PresplitStep):
    """Preprocessing step that casts every Boolean column to an integer column.

    Registered in the baseline registry as ``"bool_to_int"``.
    """

    def __init__(self):
        pass

    def apply(self, input_df: PolarsFrame_T0) -> PolarsFrame_T0:
        """Return ``input_df`` with all Boolean columns cast to ``int``.

        Columns of any other dtype are left untouched. When the frame has no
        Boolean columns, it is returned as-is.
        """
        # Read the schema once instead of once per column; for a lazy frame
        # each schema access may have to be resolved again.
        schema = input_df.schema  # type: ignore
        boolean_columns = [
            col_name for col_name, dtype in schema.items() if dtype == Boolean
        ]
        if not boolean_columns:
            return input_df

        # One with_columns call casts every Boolean column in place, keeping
        # names and column order unchanged.
        return input_df.with_columns(pl.col(boolean_columns).cast(int))
17 changes: 17 additions & 0 deletions
17
psycop/projects/cvd/model_training/preprocessing/datetime_filter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import polars.selectors as cs | ||
|
||
from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry | ||
from psycop.common.model_training_v2.trainer.preprocessing.step import ( | ||
PolarsFrame_T0, | ||
PresplitStep, | ||
) | ||
|
||
|
||
@BaselineRegistry.preprocessing.register("temporal_col_filter")
class TemporalColumnFilter(PresplitStep):
    """Preprocessing step that drops every column matched by ``cs.temporal()``.

    Registered in the baseline registry as ``"temporal_col_filter"``.
    """

    def __init__(self):
        pass

    def apply(self, input_df: PolarsFrame_T0) -> PolarsFrame_T0:
        """Return ``input_df`` without its temporal columns."""
        # Resolve the selector to concrete column names first, then drop them.
        names_to_drop = input_df.select(cs.temporal()).columns
        remaining = input_df.drop(names_to_drop)
        return remaining
19 changes: 19 additions & 0 deletions
19
psycop/projects/cvd/model_training/preprocessing/regex_filter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import polars as pl | ||
|
||
from psycop.common.model_training_v2.config.baseline_registry import BaselineRegistry | ||
from psycop.common.model_training_v2.trainer.preprocessing.step import ( | ||
PolarsFrame_T0, | ||
PresplitStep, | ||
) | ||
|
||
|
||
@BaselineRegistry.preprocessing.register("regex_column_blacklist")
class RegexColumnBlacklist(PresplitStep):
    """Preprocessing step that removes columns whose full name matches a regex.

    Registered in the baseline registry as ``"regex_column_blacklist"``.
    """

    def __init__(self, *args: str):
        """Store the blacklist patterns.

        Args:
            *args: Regular expressions; a column is dropped when its whole
                name matches any one of them.
        """
        self.regex_blacklist = args

    def apply(self, input_df: PolarsFrame_T0) -> PolarsFrame_T0:
        """Return ``input_df`` without the columns matching any blacklist pattern."""
        for blacklist in self.regex_blacklist:
            # Wrap the pattern in a non-capturing group before anchoring.
            # Without the group, a top-level alternation such as "a|b" would
            # become "^a|b$", i.e. "starts with a" OR "ends with b", instead
            # of a whole-name match against either alternative.
            input_df = input_df.select(pl.exclude(f"^(?:{blacklist})$"))

        return input_df
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from pathlib import Path | ||
|
||
from psycop.common.model_training_v2.config.baseline_pipeline import ( | ||
train_baseline_model, | ||
) | ||
from psycop.common.model_training_v2.config.config_utils import ( | ||
load_baseline_config, | ||
) | ||
from psycop.projects.cvd.model_training.populate_cvd_registry import ( | ||
populate_with_cvd_registry, | ||
) | ||
|
||
if __name__ == "__main__":
    # Import the CVD components first; presumably the config below refers to
    # them by their registry names (verify against cvd_baseline.cfg).
    populate_with_cvd_registry()

    # The config file is looked up next to this script.
    config_path = Path(__file__).parent / "cvd_baseline.cfg"
    train_baseline_model(load_baseline_config(config_path))