diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index 8f4fbc80..aacea204 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -37,7 +37,6 @@ def make_model(self, activation: nn.Module = nn.LeakyReLU()): is_model : bool True because an MLP is a model. """ - activation.to(self.device) if self.shared_weights == True and self.base is not None: mlp = nn.Sequential( @@ -66,7 +65,7 @@ def make_model(self, activation: nn.Module = nn.LeakyReLU()): + self.tail ) ) - return mlp, True + return mlp.to(self.device), True else: raise ValueError( "Base Model must be provided when shared_weights is set to True"