gcn_layer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv, SAGEConv

class GCNLayer(nn.Module):
    """Graph convolution layer (DGL GraphConv) with optional batch norm,
    dropout, and residual connection."""

    def __init__(self,
                 g,
                 in_dim,
                 out_dim,
                 activation,
                 normalize=True,
                 batch_norm=True,
                 dropout=0,
                 residual=True,
                 bias=True):
        super(GCNLayer, self).__init__()
        self.g = g
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.activation = activation
        self.normalize = normalize
        self.batch_norm = batch_norm
        self.residual = residual
        self.bias = bias

        # A residual connection is only possible when input and output
        # dimensions match.
        if self.in_channels != self.out_channels:
            self.residual = False

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)

        self.batchnorm_h = nn.BatchNorm1d(self.out_channels)

        # 'both' applies the symmetric D^{-1/2} A D^{-1/2} normalization.
        gcn_norm = 'both' if self.normalize else 'none'
        self.conv = GraphConv(self.in_channels, self.out_channels, norm=gcn_norm, bias=self.bias)

    def forward(self, feature):
        h_in = feature  # kept for the residual connection
        h = feature
        if self.dropout is not None:
            h = self.dropout(h)
        h = self.conv(self.g, h)
        if self.batch_norm:
            # combine all non-feature dimensions so BatchNorm1d sees (N, C)
            shape_orig = h.shape
            h = h.view(-1, h.shape[-1])
            h = self.batchnorm_h(h)
            h = h.view(*shape_orig)
        if self.activation:
            h = self.activation(h)
        if self.residual:
            h = h_in + h  # residual connection
        return h

    def __repr__(self):
        return '{}(in_channels={}, out_channels={}, batch_norm={}, dropout={}, residual={})'.format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels,
            self.batch_norm,
            self.dropout,
            self.residual)

class GraphSAGELayer(nn.Module):
    """GraphSAGE layer (DGL SAGEConv with a mean aggregator) with optional
    batch norm, dropout, and residual connection."""

    def __init__(self,
                 g,
                 in_dim,
                 out_dim,
                 activation,
                 batch_norm=True,
                 dropout=0,
                 residual=True,
                 bias=True):
        super(GraphSAGELayer, self).__init__()
        self.g = g
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.activation = activation
        self.aggregator = "mean"
        self.batch_norm = batch_norm
        self.residual = residual
        self.bias = bias

        # A residual connection is only possible when input and output
        # dimensions match.
        if self.in_channels != self.out_channels:
            self.residual = False

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)

        self.batchnorm_h = nn.BatchNorm1d(self.out_channels)
        self.conv = SAGEConv(self.in_channels, self.out_channels, self.aggregator, bias=self.bias)

    def forward(self, feature):
        h_in = feature  # kept for the residual connection
        h = feature
        if self.dropout is not None:
            h = self.dropout(h)
        h = self.conv(self.g, h)
        if self.batch_norm:
            # combine all non-feature dimensions so BatchNorm1d sees (N, C)
            shape_orig = h.shape
            h = h.view(-1, h.shape[-1])
            h = self.batchnorm_h(h)
            h = h.view(*shape_orig)
        if self.activation:
            h = self.activation(h)
        if self.residual:
            h = h_in + h  # residual connection
        return h

    def __repr__(self):
        return '{}(in_channels={}, out_channels={}, aggregator={}, batch_norm={}, dropout={}, residual={})'.format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels,
            self.aggregator,
            self.batch_norm,
            self.dropout,
            self.residual)
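
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds a tiny
# 4-node cycle graph with DGL, adds self-loops so GraphConv's symmetric
# normalization sees no zero-in-degree nodes, and runs both layers once.
# The graph, feature width, and dropout value below are illustrative
# assumptions, not settings from the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import dgl

    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([1, 2, 3, 0])
    g = dgl.add_self_loop(dgl.graph((src, dst), num_nodes=4))

    feats = torch.randn(4, 8)  # 4 nodes, 8 input features (hypothetical sizes)

    gcn = GCNLayer(g, in_dim=8, out_dim=8, activation=F.relu, dropout=0.1)
    sage = GraphSAGELayer(g, in_dim=8, out_dim=8, activation=F.relu, dropout=0.1)

    print(gcn(feats).shape)   # expected: torch.Size([4, 8])
    print(sage(feats).shape)  # expected: torch.Size([4, 8])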