From da184f09b6d6a7c385f2cacba6f0aa12cde9057a Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Mon, 28 Oct 2024 22:10:50 -0400 Subject: [PATCH] small fix for last_layer hanging gradients with PAINN --- hydragnn/models/PAINNStack.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hydragnn/models/PAINNStack.py b/hydragnn/models/PAINNStack.py index dd9f9ebf2..11f062461 100644 --- a/hydragnn/models/PAINNStack.py +++ b/hydragnn/models/PAINNStack.py @@ -44,7 +44,9 @@ def __init__( def _init_conv(self): last_layer = 1 == self.num_conv_layers - self.graph_convs.append(self.get_conv(self.input_dim, self.hidden_dim)) + self.graph_convs.append( + self.get_conv(self.input_dim, self.hidden_dim, last_layer) + ) self.feature_layers.append(nn.Identity()) for i in range(self.num_conv_layers - 1): last_layer = i == self.num_conv_layers - 2