diff --git a/pypots/base.py b/pypots/base.py index 823bf06b..0d759646 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -543,7 +543,8 @@ def _print_model_size(self) -> None: """Print the number of trainable parameters in the initialized NN model.""" num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) logger.info( - f"Model initialized successfully with the number of trainable parameters: {num_params:,}" + f"A {self.__class__.__name__} model initialized with the given hyperparameters, " + f"the number of trainable parameters: {num_params:,}" ) @abstractmethod