Skip to content

Commit

Permalink
feat: print the model's name when logging the number of model paramet…
Browse files Browse the repository at this point in the history
…ers;
  • Loading branch information
WenjieDu committed Dec 20, 2023
1 parent 15afa72 commit 88056a4
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,8 @@ def _print_model_size(self) -> None:
"""Print the number of trainable parameters in the initialized NN model."""
num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
logger.info(
f"Model initialized successfully with the number of trainable parameters: {num_params:,}"
f"A {self.__class__.__name__} model initialized with the given hyperparameters, "
f"the number of trainable parameters: {num_params:,}"
)

@abstractmethod
Expand Down

0 comments on commit 88056a4

Please sign in to comment.