-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a783b85
commit d1ec263
Showing
24 changed files
with
327 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
""" Imports for the DGI. """ | ||
from .models import DGI | ||
from .utils import DGILoss | ||
from .utils.loss import DGILoss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
""" TODO: module docstring for dgi/layers/__init__.py. """ | ||
from .gcn import GCN | ||
from .readout import AvgReadout | ||
from .discriminator import Discriminator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,27 @@ | ||
import torch.nn as nn | ||
""" TODO: module docstring for dgi/layers/gcn.py. """ | ||
from torch import nn | ||
import torch_geometric.nn as tg_nn | ||
|
||
|
||
class GCN(nn.Module): | ||
""" TODO: class docstring for GCN. """ | ||
def __init__(self, in_ft, out_ft, act, bias=True): | ||
super(GCN, self).__init__() | ||
super().__init__() | ||
self.conv = tg_nn.GCNConv(in_channels=in_ft, out_channels=out_ft, bias=bias) | ||
self.act = nn.PReLU() if act == 'prelu' else act | ||
self.act = nn.PReLU() if act == "prelu" else act | ||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
""" TODO: method docstring for GCN.reset_parameters. """ | ||
self.conv.reset_parameters() | ||
if hasattr(self.act, 'reset_parameters'): | ||
if hasattr(self.act, "reset_parameters"): | ||
self.act.reset_parameters() | ||
elif isinstance(self.act, nn.PReLU): | ||
self.act.weight.data.fill_(0.25) | ||
|
||
# Shape of seq: (batch, nodes, features) | ||
def forward(self, seq, adj): | ||
""" TODO: method docstring for GCN.forward. """ | ||
out = self.conv(seq, adj) | ||
|
||
return self.act(out) | ||
|
||
return self.act(out) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,21 @@ | ||
"""TODO: module docstring for dgi/layers/readout.py.""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch import nn | ||
|
||
|
||
# Applies an average on seq, of shape (batch, nodes, features) | ||
# While taking into account the masking of msk | ||
class AvgReadout(nn.Module): | ||
"""TODO: class docstring for AvgReadout.""" | ||
|
||
def __init__(self): | ||
super(AvgReadout, self).__init__() | ||
super().__init__() | ||
|
||
def forward(self, seq, msk): | ||
"""TODO: method docstring for AvgReadout.forward.""" | ||
if msk is None: | ||
return torch.mean(seq, 0) | ||
else: | ||
msk = torch.unsqueeze(msk, -1) | ||
return torch.sum(seq * msk, 0) / torch.sum(msk) | ||
|
||
msk = torch.unsqueeze(msk, -1) | ||
return torch.sum(seq * msk, 0) / torch.sum(msk) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
""" TODO: module docstring for dgi/models/__init__.py. """ | ||
from .dgi import DGI | ||
from .logreg import LogReg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,25 @@ | ||
""" TODO: module docstring for dgi/models/logreg.py. """ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch import nn | ||
|
||
|
||
class LogReg(nn.Module): | ||
""" TODO: class docstring for LogReg. """ | ||
def __init__(self, ft_in, nb_classes): | ||
super(LogReg, self).__init__() | ||
super().__init__() | ||
self.fc = nn.Linear(ft_in, nb_classes) | ||
|
||
for m in self.modules(): | ||
self.weights_init(m) | ||
|
||
def weights_init(self, m): | ||
""" TODO: method docstring for LogReg.weights_init. """ | ||
if isinstance(m, nn.Linear): | ||
torch.nn.init.xavier_uniform_(m.weight.data) | ||
if m.bias is not None: | ||
m.bias.data.fill_(0.0) | ||
|
||
def forward(self, seq): | ||
""" TODO: method docstring for LogReg.forward """ | ||
ret = self.fc(seq) | ||
return ret | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
""" TODO: module docstring for dgi/utils/__init__.py. """ | ||
from .loss import DGILoss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.