feat-312: replace prints by warnings
Optimox committed Nov 5, 2021
1 parent 709fcb1 commit b80eeec
Showing 2 changed files with 8 additions and 7 deletions.
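
The practical effect of the change: messages that previously went straight to stdout via print are now emitted through Python's warnings machinery, so callers can record, redirect, filter, or silence them. A minimal sketch of the difference, reusing the device message changed below; the surrounding code is illustrative, not part of the library:

    import warnings

    # Before this commit: the message is written unconditionally to stdout.
    print("Device used : cpu")

    # After this commit: the message is emitted as a UserWarning, which the
    # standard warnings machinery lets callers record, filter, or escalate.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        warnings.warn("Device used : cpu")
        assert "Device used" in str(caught[0].message)
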
11 changes: 5 additions & 6 deletions pytorch_tabnet/abstract_model.py
@@ -70,7 +70,7 @@ def __post_init__(self):
         # Defining device
         self.device = torch.device(define_device(self.device_name))
         if self.verbose != 0:
-            print(f"Device used : {self.device}")
+            warnings.warn(f"Device used : {self.device}")
 
     def __update__(self, **kwargs):
         """
@@ -210,9 +210,8 @@ def fit(
         self._set_callbacks(callbacks)
 
         if from_unsupervised is not None:
-            print("Loading weights from unsupervised pretraining")
             self.load_weights_from_unsupervised(from_unsupervised)
-
+            warnings.warn("Loading weights from unsupervised pretraining")
         # Call method on_train_begin for all callbacks
         self._callback_container.on_train_begin()
 
@@ -622,9 +621,9 @@ def _set_callbacks(self, custom_callbacks):
             )
             callbacks.append(early_stopping)
         else:
-            print(
-                "No early stopping will be performed, last training weights will be used."
-            )
+            wrn_msg = "No early stopping will be performed, last training weights will be used."
+            warnings.warn(wrn_msg)
 
         if self.scheduler_fn is not None:
             # Add LR Scheduler call_back
             is_batch_level = self.scheduler_params.pop("is_batch_level", False)
4 changes: 3 additions & 1 deletion pytorch_tabnet/callbacks.py
@@ -4,6 +4,7 @@
 import numpy as np
 from dataclasses import dataclass, field
 from typing import List, Any
+import warnings
 
 
 class Callback:
@@ -167,7 +168,8 @@ def on_train_end(self, logs=None):
                 + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}"
             )
             print(msg)
-        print("Best weights from best epoch are automatically used!")
+        wrn_msg = "Best weights from best epoch are automatically used!"
+        warnings.warn(wrn_msg)
 
 
 @dataclass
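
Because these messages are now ordinary UserWarnings, Python's default filter reports each of them only once per call site rather than on every call, and a user who finds them noisy can silence them without touching library code. A hedged user-side sketch; the message prefixes are copied from the diff above, and the filtering calls are plain standard library:

    import warnings

    # "message" is a regex matched against the start of the warning text.
    warnings.filterwarnings("ignore", message="Device used")
    warnings.filterwarnings("ignore", message="No early stopping will be performed")
    warnings.filterwarnings("ignore", message="Best weights from best epoch")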
