Showing 15 changed files with 316 additions and 46 deletions.
@@ -0,0 +1,23 @@
from typing import Dict, Type
from abc import ABC, abstractmethod
from pathlib import Path
from omegaconf import DictConfig
from mixins import TimeableMixin


class BaseModel(ABC, TimeableMixin):
    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def evaluate(self) -> float:
        pass

    @abstractmethod
    def save_model(self, output_fp: Path):
        pass
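For orientation, here is a minimal sketch of a concrete subclass implementing this interface; the class and its internals are hypothetical and not part of this commit:

import json
from pathlib import Path


class ConstantModel(BaseModel):
    """Toy model that predicts a constant score; illustrates the required hooks."""

    def __init__(self, cfg: DictConfig):
        self.cfg = cfg
        self.score_ = 0.5

    def train(self):
        # A real model would fit on the tabularized training shards here.
        pass

    def evaluate(self) -> float:
        # A real model would return ROC AUC on the tuning or held-out split.
        return self.score_

    def save_model(self, output_fp: Path):
        output_fp.write_text(json.dumps({"score": self.score_}))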
@@ -0,0 +1,22 @@
defaults:
  - default
  - tabularization: default
  - model: xgboost  # This can be changed to sgd_classifier or any other model
  - override hydra/sweeper: optuna
  - override hydra/sweeper/sampler: tpe
  - override hydra/launcher: joblib
  - _self_

task_name: task

# Task cached data dir
input_dir: ${output_cohort_dir}/${task_name}/task_cache
# Directory with task labels
input_label_dir: ${output_cohort_dir}/${task_name}/labels/
# Where to output the model and cached data
model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
output_filepath: ${model_dir}/model_metadata.json

log_dir: ${model_dir}/.logs/

name: launch_model
src/MEDS_tabular_automl/configs/models/sgd_classifier.yaml (19 additions, 0 deletions)
@@ -0,0 +1,19 @@
model_params:
  epochs: 20
  early_stopping_rounds: 5
  model:
    type: sklearn
    _target_: sklearn.linear_model.SGDClassifier
    loss: log_loss
  iterator:
    keep_data_in_memory: True
    binarize_task: True

hydra:
  sweeper:
    params:
      +model_params.model.alpha: tag(log, interval(1e-6, 1))
      +model_params.model.l1_ratio: interval(0, 1)
      +model_params.model.penalty: choice(['l1', 'l2', 'elasticnet'])
      model_params.epochs: range(10, 100)
      model_params.early_stopping_rounds: range(1, 10)
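The `_target_` key suggests the sklearn estimator is built by Hydra-style instantiation. A minimal sketch of that pattern, assuming the dispatch-only `type` key is stripped first (the actual SklearnModel code is not shown in this diff and may differ):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirror of the `model` node above, minus the dispatch-only `type` key.
model_cfg = OmegaConf.create(
    {"_target_": "sklearn.linear_model.SGDClassifier", "loss": "log_loss"}
)
clf = instantiate(model_cfg)  # equivalent to SGDClassifier(loss="log_loss")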
@@ -0,0 +1,27 @@
model_params:
  num_boost_round: 1000
  early_stopping_rounds: 5
  model:
    type: xgboost
    # _target_: xgboost.XGBClassifier
    booster: gbtree
    device: cpu
    nthread: 1
    tree_method: hist
    objective: binary:logistic
  iterator:
    keep_data_in_memory: True
    binarize_task: True

hydra:
  sweeper:
    params:
      +model_params.model.eta: tag(log, interval(0.001, 1))
      +model_params.model.lambda: tag(log, interval(0.001, 1))
      +model_params.model.alpha: tag(log, interval(0.001, 1))
      +model_params.model.subsample: interval(0.5, 1)
      +model_params.model.min_child_weight: interval(1e-2, 100)
      model_params.num_boost_round: range(100, 1000)
      model_params.early_stopping_rounds: range(1, 10)
      +model_params.model.max_depth: range(2, 16)
      tabularization.min_code_inclusion_frequency: tag(log, range(10, 1000000))
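As a point of reference, these parameters map onto XGBoost's native training API roughly as follows. This is a sketch on toy data; the real pipeline feeds sparse tabularized shards, so the actual XGBoostModel code will differ:

import numpy as np
import xgboost as xgb

# Toy stand-in for the tabularized features and binary labels.
X = np.random.rand(256, 8)
y = np.random.randint(0, 2, 256)
dtrain = xgb.DMatrix(X, label=y)

params = {
    "booster": "gbtree",
    "device": "cpu",
    "nthread": 1,
    "tree_method": "hist",
    "objective": "binary:logistic",
}
booster = xgb.train(
    params,
    dtrain,
    num_boost_round=1000,
    evals=[(dtrain, "train")],  # early stopping needs an eval set; a real run would use a tuning split
    early_stopping_rounds=5,
)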
@@ -0,0 +1,54 @@
from pathlib import Path

import hydra
import numpy as np
import scipy.sparse as sp
from loguru import logger
from mixins import TimeableMixin
from omegaconf import DictConfig
from sklearn.metrics import roc_auc_score

from .tabular_dataset import TabularDataset
from .base_model import BaseModel


class DenseIterator(TabularDataset, TimeableMixin):
    def __init__(self, cfg: DictConfig, split: str):
        """Initializes the DenseIterator with the provided configuration and data split.

        Args:
            cfg: The configuration dictionary.
            split: The data split to use.
        """
        TabularDataset.__init__(self, cfg=cfg, split=split)
        TimeableMixin.__init__(self)
        self.valid_event_ids, self.labels = self._load_ids_and_labels()
        # check if the labels are empty
        if len(self.labels) == 0:
            raise ValueError("No labels found.")
        # self._it = 0

    def densify(self) -> tuple[sp.spmatrix, np.ndarray, list[str]]:
        """Builds the data matrix based on column subselection."""
        # get the column indices to include
        cols = self.get_feature_indices()

        # map those to the feature names in the data
        feature_names = self.get_all_column_names()
        selected_features = [feature_names[col] for col in cols]

        # get the matrix by iterating through the data shards
        data = []
        labels = []
        for shard_idx in range(len(self._data_shards)):
            shard_data, shard_labels = self.get_data_shards(shard_idx)
            shard_data = shard_data[:, cols]
            data.append(shard_data)
            labels.append(shard_labels)
        data = sp.vstack(data)
        labels = np.concatenate(labels, axis=0)
        return data, labels, selected_features
@@ -0,0 +1,60 @@
from importlib.resources import files
from pathlib import Path

import hydra
import pandas as pd
from loguru import logger
from omegaconf import DictConfig

from MEDS_tabular_automl.dense_iterator import DenseIterator

from ..utils import hydra_loguru_init

config_yaml = files("MEDS_tabular_automl").joinpath("configs/launch_xgboost.yaml")
if not config_yaml.is_file():
    raise FileNotFoundError("Core configuration not successfully installed!")


@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
def main(cfg: DictConfig) -> float:
    """Launches AutoGluon after collecting data based on the provided configuration.

    Args:
        cfg: The configuration dictionary specifying model and training parameters.
    """
    # print(OmegaConf.to_yaml(cfg))
    if not cfg.loguru_init:
        hydra_loguru_init()

    # check that AutoGluon is installed; fail fast if it is not
    try:
        from autogluon.tabular import TabularPredictor
    except ImportError:
        logger.error("AutoGluon is not installed. Please install AutoGluon.")
        raise

    # collect data based on the configuration
    itrain = DenseIterator(cfg, "train")
    ituning = DenseIterator(cfg, "tuning")
    iheld_out = DenseIterator(cfg, "held_out")

    # collect data for AutoGluon
    train_data, train_labels, cols = itrain.densify()
    tuning_data, tuning_labels, _ = ituning.densify()
    held_out_data, held_out_labels, _ = iheld_out.densify()

    # construct dfs for AutoGluon
    train_df = pd.DataFrame(train_data.todense(), columns=cols)
    train_df[cfg.task_name] = train_labels
    tuning_df = pd.DataFrame(tuning_data.todense(), columns=cols)
    tuning_df[cfg.task_name] = tuning_labels
    held_out_df = pd.DataFrame(held_out_data.todense(), columns=cols)
    held_out_df[cfg.task_name] = held_out_labels

    # launch AutoGluon
    predictor = TabularPredictor(label=cfg.task_name).fit(train_data=train_df, tuning_data=tuning_df)


if __name__ == "__main__":
    main()
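The held-out frame is constructed but not yet scored at this point. A hedged sketch of a plausible follow-up using AutoGluon's evaluation API (not code from this commit; `predictor` and `held_out_df` as built above):

# Hypothetical follow-up inside main(), after fitting `predictor`:
held_out_scores = predictor.evaluate(held_out_df)
logger.info(f"Held-out metrics: {held_out_scores}")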
@@ -0,0 +1,63 @@
from importlib.resources import files
from pathlib import Path
from typing import Dict, Type

import hydra
from loguru import logger
from omegaconf import DictConfig

from MEDS_tabular_automl.base_model import BaseModel
from MEDS_tabular_automl.sklearn_model import SklearnModel
from MEDS_tabular_automl.xgboost_model import XGBoostModel

from ..utils import hydra_loguru_init

MODEL_CLASSES: Dict[str, Type[BaseModel]] = {
    "xgboost": XGBoostModel,
    "sklearn": SklearnModel,
}

config_yaml = files("MEDS_tabular_automl").joinpath("configs/launch_xgboost.yaml")
if not config_yaml.is_file():
    raise FileNotFoundError("Core configuration not successfully installed!")


@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
def main(cfg: DictConfig) -> float:
    """Optimizes the model based on the provided configuration.

    Args:
        cfg: The configuration dictionary specifying model and training parameters.

    Returns:
        The evaluation result as the ROC AUC score on the held-out test set.
    """
    # print(OmegaConf.to_yaml(cfg))
    if not cfg.loguru_init:
        hydra_loguru_init()
    try:
        model_type = cfg.model.type
        ModelClass = MODEL_CLASSES.get(model_type)
        if ModelClass is None:
            raise ValueError(f"Model type {model_type} not supported.")

        model = ModelClass(cfg)
        model.train()
        auc = model.evaluate()
        logger.info(f"AUC: {auc}")

        # save model
        output_fp = Path(cfg.output_filepath)
        output_fp.parent.mkdir(parents=True, exist_ok=True)

        model.save_model(output_fp)
    except Exception as e:
        logger.error(f"Error occurred: {e}")
        auc = 0.0
    return auc


if __name__ == "__main__":
    main()
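Dispatch is keyed on `model.type` from the selected model config (`xgboost` or `sklearn` above). A hypothetical sketch of wiring in an additional model, reusing a BaseModel subclass like the ConstantModel sketched earlier (none of these names are part of the commit):

# Hypothetical registration; a matching config would set `model.type: constant`.
MODEL_CLASSES["constant"] = ConstantModel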
@@ -108,7 +108,6 @@ def load_tqdm(use_tqdm: bool):
        return tqdm
    else:

        def noop(x, **kwargs):
            return x