Skip to content

Commit

Permalink
Fix setting model.criterio, using kwargs instead of hparams.
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlesGaydon committed Feb 6, 2024
1 parent 9387ddb commit 9c6a1a6
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions myria3d/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def __init__(self, **kwargs):
# it also allows to access params with 'self.hparams' attribute
self.save_hyperparameters(ignore=["criterion"])

neural_net_class = get_neural_net_class(self.hparams.neural_net_class_name)
self.model = neural_net_class(**self.hparams.neural_net_hparams)
neural_net_class = get_neural_net_class(kwargs.get("neural_net_class_name"))
self.model = neural_net_class(**kwargs.get("neural_net_hparams"))

self.softmax = nn.Softmax(dim=1)
self.criterion = self.hparams.criterion
self.criterion = kwargs.get("criterion")

def on_fit_start(self) -> None:
self.criterion = self.criterion.to(self.device)
Expand Down

0 comments on commit 9c6a1a6

Please sign in to comment.