def _predict_from_engineered_x(model, model_type, X):
    """Create predictions with ``model`` from an already engineered ``X``.

    The prediction routine is selected from the runtime type of ``model``:
    a ``NeuralNet`` is called directly, a ``dict`` is treated as a bundled
    trac model, an ``xgb.core.Booster`` is fed an ``xgb.DMatrix`` and every
    other object is expected to expose ``predict``.

    Args:
        model: Fitted model object (see dispatch above).
        model_type: Label of the model family. Currently not used for the
            dispatch; kept so that callers can keep passing it.
        X: Engineered feature matrix the model was trained on.

    Returns:
        The predictions for the rows of ``X``.
    """
    if isinstance(model, NeuralNet):
        # Predict in eval mode and restore the previous mode afterwards so
        # the helper gives the same result regardless of the caller's state.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                # Build the input on the module's device; a CPU tensor
                # fails as soon as the module lives on an accelerator.
                X_t = torch.tensor(X, dtype=torch.float32, device=model.device)
                predicted = model._prepare_predictions(model(X_t))
        finally:
            model.train(was_training)

        if isinstance(predicted, torch.Tensor):
            # ``Tensor.values`` is a method, not the data - convert to a
            # flat numpy array instead.
            return predicted.detach().cpu().numpy().flatten()
        # NOTE(review): assumes a pandas-like return value exposing
        # ``.values`` - confirm against ``NeuralNet._prepare_predictions``.
        return predicted.values

    if isinstance(model, dict):
        # trac model
        log_geom, _ = _preprocess_taxonomy_aggregation(X, model["matrix_a"].values)
        alpha = model["model"].values
        # alpha[0] is added as offset, alpha[1:] are the coefficients
        predicted = log_geom.dot(alpha[1:]) + alpha[0]
        return predicted.flatten()

    if isinstance(model, xgb.core.Booster):
        return model.predict(xgb.DMatrix(X)).flatten()

    return model.predict(X).flatten()


def get_n_save_predictions(
    model, model_type, X_train, y_train, idx_train, X_val, y_val, idx_val
):
    """Predict on the train and validation split and save the result as CSV.

    For both splits the true and the predicted target are collected in one
    dataframe (columns ``true``, ``pred`` and ``split``) that is indexed by
    the provided indices. The dataframe is written to
    ``debug_last_log_vs_preds.csv`` in the directory of the current trial.

    Args:
        model: Fitted model, passed on to ``_predict_from_engineered_x``.
        model_type: Label of the model family, passed on unchanged.
        X_train: Engineered features of the train split.
        y_train: True target of the train split.
        idx_train: Index to use for the rows of the train split.
        X_val: Engineered features of the validation split.
        y_val: True target of the validation split.
        idx_val: Index to use for the rows of the validation split.

    Returns:
        pd.DataFrame with the predictions of both splits.
    """
    split_dic = {
        "train": (X_train, y_train, idx_train),
        "val": (X_val, y_val, idx_val),
    }
    pred_ls = []
    for split, (X, y, idx) in split_dic.items():
        y_pred = _predict_from_engineered_x(model, model_type, X)
        pred_df = pd.DataFrame({"true": y, "pred": y_pred}, index=idx)
        pred_df["split"] = split
        pred_ls.append(pred_df)
    all_pred = pd.concat(pred_ls)

    trial_path = train.get_context().get_trial_dir()
    # todo: once you removed the former predictions -> rename to no suffix
    path2save = os.path.join(trial_path, "debug_last_log_vs_preds.csv")
    all_pred.to_csv(path2save, index=True)
    return all_pred
save predictions + _ = get_n_save_predictions( + linreg, "linreg", X_train, y_train, idx_train, X_val, y_val, idx_val + ) + _report_results_manually(linreg, X_train, y_train, X_val, y_val, tax) @@ -230,7 +279,7 @@ def train_trac( None """ # ! process dataset: X with features & y with host_id - X_train, y_train, X_val, y_val, ft_col = process_train( + X_train, y_train, idx_train, X_val, y_val, idx_val, ft_col = process_train( config, train_val, target, host_id, tax, seed_data ) # ! derive matrix A @@ -257,7 +306,11 @@ def train_trac( matrices_train, selected_param, intercept=intercept ) + # ! save predictions model = _bundle_trac_model(alpha, a_df) + _ = get_n_save_predictions( + model, "trac", X_train, y_train, idx_train, X_val, y_val, idx_val + ) _report_results_manually_trac( model, log_geom_train, y_train, log_geom_val, y_val, tax @@ -289,7 +342,7 @@ def train_rf( None """ # ! process dataset - X_train, y_train, X_val, y_val, ft_col = process_train( + X_train, y_train, idx_train, X_val, y_val, idx_val, ft_col = process_train( config, train_val, target, host_id, tax, seed_data ) @@ -307,11 +360,14 @@ def train_rf( ) rf.fit(X_train, y_train) + # ! 
class PostTrainingCallback(Callback):
    """Save train and validation predictions at the end of validation.

    The callback stores the engineered data of both splits and hands it,
    together with the current module, to ``get_n_save_predictions``.
    """

    def __init__(self, nn_type, X_train, y_train, idx_train, X_val, y_val, idx_val):
        super().__init__()
        self.nn_type = nn_type
        self.X_train = X_train
        self.y_train = y_train
        self.idx_train = idx_train
        self.X_val = X_val
        self.y_val = y_val
        self.idx_val = idx_val

    def on_validation_epoch_end(self, trainer, pl_module):
        """Save the predictions of ``pl_module`` for both splits."""
        # The sanity-check validation pass runs before any training step:
        # predictions of the untrained module are of no use, so skip it.
        if trainer.sanity_checking:
            return
        _ = get_n_save_predictions(
            pl_module,
            self.nn_type,
            self.X_train,
            self.y_train,
            self.idx_train,
            self.X_val,
            self.y_val,
            self.idx_val,
        )


class CustomTuneReportCallback(TuneReportCheckpointCallback):
    """TuneReportCheckpointCallback running an optional callback beforehand.

    Args:
        *args: Passed on to ``TuneReportCheckpointCallback``.
        post_training_callback: Optional callback whose
            ``on_validation_epoch_end`` is invoked before the parent's.
        **kwargs: Passed on to ``TuneReportCheckpointCallback``.
    """

    def __init__(self, *args, post_training_callback=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.post_training_callback = post_training_callback

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run the post-training callback, then the parent's hook."""
        # this ensures that the predictions are saved before
        # TuneReportCheckpointCallback is called and tune is getting the signal
        # to stop the trial
        if self.post_training_callback is not None:
            self.post_training_callback.on_validation_epoch_end(trainer, pl_module)
        super().on_validation_epoch_end(trainer, pl_module)
process dataset - X_train, y_train, X_val, y_val, ft_col = process_train( + X_train, y_train, idx_train, X_val, y_val, idx_val, ft_col = process_train( config, train_val, target, host_id, tax, seed_data ) # Set seeds @@ -682,12 +762,14 @@ def train_xgb( ) # todo: add test set here to be tracked as well - xgb.train( + xgb_model = xgb.train( config, dtrain, evals=[(dtrain, "train"), (dval, "val")], callbacks=[checkpoint_callback], custom_metric=custom_xgb_metric, ) - - # TODO: add test set here to be tracked as well + # ! save predictions + _ = get_n_save_predictions( + xgb_model, "xgb", X_train, y_train, idx_train, X_val, y_val, idx_val + )