Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mpact][compiler] add gcn model (with test) #41

Merged
merged 3 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions python/mpact/models/gcn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch.nn.functional as F


class GraphConv(torch.nn.Module):
Expand All @@ -19,5 +20,28 @@ def forward(self, inp, adj_mat):
return output


class GCN(torch.nn.Module):
"""
Graph Convolutional Network (GCN) inspired by <https://arxiv.org/pdf/1609.02907.pdf>.
"""

def __init__(self, input_dim, hidden_dim, output_dim, dropout_p=0.1):
super(GCN, self).__init__()
self.gc1 = GraphConv(input_dim, hidden_dim)
self.gc2 = GraphConv(hidden_dim, output_dim)
self.dropout = torch.nn.Dropout(dropout_p)

def forward(self, input_tensor, adj_mat):
x = self.gc1(input_tensor, adj_mat)
x = F.relu(x)
x = self.dropout(x)
x = self.gc2(x, adj_mat)
return F.log_softmax(x, dim=1)


def graphconv44():
return GraphConv(input_dim=4, output_dim=4)


def gcn4164():
return GCN(input_dim=4, hidden_dim=16, output_dim=4)
34 changes: 33 additions & 1 deletion test/python/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run

from mpact.models.gcn import graphconv44
from mpact.models.gcn import graphconv44, gcn4164

net = graphconv44()
net.eval() # Switch to inference.

# Get random (but reproducible) matrices.
torch.manual_seed(0)
Expand Down Expand Up @@ -60,3 +61,34 @@
print("mpact run")
res = mpact_jit_run(invoker, fn, inp, adj_mat)
print(res)

net = gcn4164()
net.eval() # Switch to inference.


# Sparse input.
idx = torch.tensor([[0, 0, 1, 2], [0, 2, 3, 1]], dtype=torch.int64)
val = torch.tensor([14.0, 3.0, -8.0, 11.0], dtype=torch.float32)
S = torch.sparse_coo_tensor(idx, val, size=[4, 4])

#
# CHECK: pytorch gcn
# CHECK: tensor({{\[}}[-1.3863, -1.3863, -1.3863, -1.3863],
# CHECK: [-1.3863, -1.3863, -1.3863, -1.3863],
# CHECK: [-1.3863, -1.3863, -1.3863, -1.3863],
# CHECK: [-1.3863, -1.3863, -1.3863, -1.3863]])
# CHECK: mpact gcn
# CHECK: {{\[}}[-1.3862944 -1.3862944 -1.3862944 -1.3862944]
# CHECK: [-1.3862944 -1.3862944 -1.3862944 -1.3862944]
# CHECK: [-1.3862944 -1.3862944 -1.3862944 -1.3862944]
# CHECK: [-1.3862944 -1.3862944 -1.3862944 -1.3862944]{{\]}}
#
with torch.no_grad():
# Run it with PyTorch.
print("pytorch gcn")
res = net(S, adj_mat)
print(res)

print("mpact gcn")
res = mpact_jit(net, S, adj_mat)
print(res)
2 changes: 1 addition & 1 deletion test/python/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mpact.models.resnet import resnet20

resnet = resnet20()
resnet.train(False) # switch to inference
resnet.eval() # Switch to inference.

# Get a random input.
# B x RGB x H x W
Expand Down
Loading