Skip to content

Commit

Permalink
deadline
Browse files Browse the repository at this point in the history
  • Loading branch information
richieBao committed Nov 5, 2023
1 parent a4cad5d commit e43cb36
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 3 deletions.
Binary file modified dist/usda-0.0.29.tar.gz
Binary file not shown.
2 changes: 2 additions & 0 deletions src/usda/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ._gnn_algorithms import VanillaGNNLayer
from ._gnn_algorithms import VanillaGNN
from ._gnn_algorithms import GCN
from ._gnn_algorithms import GATv2

__all__ = [
"G_drawing",
Expand All @@ -37,5 +38,6 @@
"VanillaGNNLayer",
"VanillaGNN",
"GCN",
"GATv2",
]

Binary file modified src/usda/network/__pycache__/_gnn_interpretation.cpython-311.pyc
Binary file not shown.
46 changes: 43 additions & 3 deletions src/usda/network/_gnn_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
"""
import torch
torch.manual_seed(0)
from torch.nn import Linear
from torch.nn import Linear, Dropout
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

from torch_geometric.nn import GATv2Conv, GCNConv


def accuracy(y_pred, y_true):
"""Calculate accuracy."""
Expand Down Expand Up @@ -113,4 +115,42 @@ def test(self, data):
out = self(data.x, data.edge_index)
acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
return acc


class GATv2(torch.nn.Module):
    """Two-layer GATv2 node classifier (Graph Attention Network v2).

    Layer 1 uses multi-head attention (``heads`` heads, outputs concatenated),
    layer 2 collapses to a single head producing class scores. ``forward``
    returns log-probabilities (``log_softmax``), so training uses ``NLLLoss``.

    Parameters
    ----------
    dim_in : int
        Number of input node features.
    dim_h : int
        Hidden size per attention head.
    dim_out : int
        Number of output classes.
    heads : int, optional
        Number of attention heads in the first layer (default 8).
    """

    def __init__(self, dim_in, dim_h, dim_out, heads=8):
        super().__init__()
        self.gat1 = GATv2Conv(dim_in, dim_h, heads=heads)
        # First layer concatenates its heads, so layer 2 sees dim_h*heads features.
        self.gat2 = GATv2Conv(dim_h * heads, dim_out, heads=1)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities of shape (num_nodes, dim_out)."""
        h = F.dropout(x, p=0.6, training=self.training)
        h = self.gat1(h, edge_index)
        h = F.elu(h)
        h = F.dropout(h, p=0.6, training=self.training)
        h = self.gat2(h, edge_index)
        return F.log_softmax(h, dim=1)

    def fit(self, data, epochs, verbose=20):
        """Train on ``data`` (a PyG ``Data`` object with train/val masks).

        Runs ``epochs + 1`` optimization steps, printing train/val loss and
        accuracy every ``verbose`` epochs.
        """
        # BUGFIX: forward() already returns log-probabilities, so the correct
        # criterion is NLLLoss. The original CrossEntropyLoss applies
        # log_softmax a second time, distorting the loss and its gradients.
        criterion = torch.nn.NLLLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=0.01)

        self.train()
        for epoch in range(epochs + 1):
            optimizer.zero_grad()
            out = self(data.x, data.edge_index)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            acc = accuracy(out[data.train_mask].argmax(dim=1), data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            if epoch % verbose == 0:
                # NOTE(review): validation metrics are computed from the
                # training-mode forward pass (dropout active), matching the
                # original behavior.
                val_loss = criterion(out[data.val_mask], data.y[data.val_mask])
                val_acc = accuracy(out[data.val_mask].argmax(dim=1), data.y[data.val_mask])
                print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc*100:.2f}%')

    @torch.no_grad()
    def test(self, data):
        """Return test-set accuracy using ``data.test_mask`` (eval mode, no grads)."""
        self.eval()
        out = self(data.x, data.edge_index)
        acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
        return acc

0 comments on commit e43cb36

Please sign in to comment.