models.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File name: models.py
Author: locke
Date created: 2018/10/5 2:37 PM
"""
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from layers import *


class GAT(nn.Module):
    """Multi-layer graph attention network stacked from MultiHeadGraphAttention layers."""

    def __init__(self, n_units, n_heads, dropout, attn_dropout, instance_normalization, diag):
        super(GAT, self).__init__()
        self.num_layer = len(n_units) - 1
        self.dropout = dropout
        self.inst_norm = instance_normalization
        if self.inst_norm:
            self.norm = nn.InstanceNorm1d(n_units[0], momentum=0.0, affine=True)
        self.layer_stack = nn.ModuleList()
        self.diag = diag
        for i in range(self.num_layer):
            # From the second layer on, the input size is the previous layer's output
            # size times its number of heads (head outputs are concatenated).
            f_in = n_units[i] * n_heads[i - 1] if i else n_units[i]
            self.layer_stack.append(MultiHeadGraphAttention(n_heads[i], f_in, n_units[i + 1], attn_dropout, diag, nn.init.ones_, False))

    def forward(self, x, adj):
        if self.inst_norm:
            x = self.norm(x)
        for i, gat_layer in enumerate(self.layer_stack):
            if i + 1 < self.num_layer:
                x = F.dropout(x, self.dropout, training=self.training)
            x = gat_layer(x, adj)  # per-head outputs, shape (n_heads, N, f_out)
            if self.diag:
                # Diagonal attention: average over heads instead of concatenating.
                x = x.mean(dim=0)
            if i + 1 < self.num_layer:
                if self.diag:
                    x = F.elu(x)
                else:
                    # Concatenate head outputs: (n_heads, N, f_out) -> (N, n_heads * f_out).
                    x = F.elu(x.transpose(0, 1).contiguous().view(adj.size(0), -1))
        if not self.diag:
            # Final layer without diagonal attention: average over heads.
            x = x.mean(dim=0)
        return x
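

# ---------------------------------------------------------------------------
# Minimal usage sketch. It assumes layers.MultiHeadGraphAttention is available
# as imported above and accepts the adjacency in whatever format layers.py
# expects (a placeholder identity matrix is used here); the node count, unit
# sizes, and head counts are illustrative only. With diag=True the heads are
# averaged, so the hidden sizes are kept equal across layers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_nodes, dim = 8, 16
    model = GAT(n_units=[dim, dim, dim], n_heads=[1, 1],
                dropout=0.0, attn_dropout=0.0,
                instance_normalization=False, diag=True)
    x = torch.randn(n_nodes, dim)   # node features
    adj = torch.eye(n_nodes)        # placeholder adjacency; replace with the real graph adjacency
    out = model(x, adj)
    print(out.shape)                # expected: torch.Size([n_nodes, dim]) when diag=True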