Skip to content

Commit

Permalink
Fix coefficient initialization.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Hale committed Jul 22, 2024
1 parent c152df3 commit 56404c4
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions sgw_torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.nn.inits import zeros, ones
from torch_geometric.typing import OptTensor
import sgw_torch
import numpy as np
Expand Down Expand Up @@ -179,7 +179,16 @@ def __init__(
):
super().__init__(in_channels, out_channels, K, lap_type, bias)

def convert_coefficients(self, ys):
def reset_parameters(self):
super().reset_parameters()
# gamma_j should be an estimate of the filter value h(x_j), which should be positive, so initialize to a positive value
for lin in self.lins:
ones(lin.weight)

def convert_coefficients(self, ys=None):
if ys is None:
ys = list(self.parameters())

k1 = len(self.lins)

def transform(Ts):
Expand All @@ -194,7 +203,7 @@ def transform(Ts):
x = torch.from_numpy(np.polynomial.chebyshev.chebpts1(k1))
a = torch.ones(k1)
b = x
ws.append(transform(a/2))
ws.append(transform(a))
ws.append(transform(b))
for _ in range(2, k1):
T = 2*x*b - a
Expand All @@ -211,6 +220,7 @@ def forward(
batch: OptTensor = None,
lambda_max: OptTensor = None,
) -> Tensor:
ys = [lin.weight for lin in self.lins]
# gamma_j (ys[j]) should be an estimate of the filter value h(x_j), which should be positive, so apply ReLU
ys = [F.relu(lin.weight) for lin in self.lins]
ws = self.convert_coefficients(ys)
return self._evaluate_chebyshev(ws, x, edge_index, edge_weight, batch, lambda_max)

0 comments on commit 56404c4

Please sign in to comment.