diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 29be89f9..f5505fdb 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -70,7 +70,7 @@ def __post_init__(self): # Defining device self.device = torch.device(define_device(self.device_name)) if self.verbose != 0: - print(f"Device used : {self.device}") + warnings.warn(f"Device used : {self.device}") def __update__(self, **kwargs): """ @@ -210,9 +210,8 @@ def fit( self._set_callbacks(callbacks) if from_unsupervised is not None: - print("Loading weights from unsupervised pretraining") self.load_weights_from_unsupervised(from_unsupervised) - + warnings.warn("Loading weights from unsupervised pretraining") # Call method on_train_begin for all callbacks self._callback_container.on_train_begin() @@ -622,9 +621,9 @@ def _set_callbacks(self, custom_callbacks): ) callbacks.append(early_stopping) else: - print( - "No early stopping will be performed, last training weights will be used." - ) + wrn_msg = "No early stopping will be performed, last training weights will be used." + warnings.warn(wrn_msg) + if self.scheduler_fn is not None: # Add LR Scheduler call_back is_batch_level = self.scheduler_params.pop("is_batch_level", False) diff --git a/pytorch_tabnet/callbacks.py b/pytorch_tabnet/callbacks.py index e521c496..cb031d54 100644 --- a/pytorch_tabnet/callbacks.py +++ b/pytorch_tabnet/callbacks.py @@ -4,6 +4,7 @@ import numpy as np from dataclasses import dataclass, field from typing import List, Any +import warnings class Callback: @@ -167,7 +168,8 @@ def on_train_end(self, logs=None): + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}" ) print(msg) - print("Best weights from best epoch are automatically used!") + wrn_msg = "Best weights from best epoch are automatically used!" + warnings.warn(wrn_msg) @dataclass