Merge pull request #16 from OxfordRSE/fix_linting_for_embedding
Fix linting for embedding
mihaeladuta authored Nov 13, 2024
2 parents 90a3e63 + d1ec263 commit a8fcb7c
Showing 24 changed files with 327 additions and 160 deletions.
1 change: 0 additions & 1 deletion l2gv2/embedding/__init__.py
@@ -17,4 +17,3 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-
3 changes: 2 additions & 1 deletion l2gv2/embedding/dgi/__init__.py
@@ -1,2 +1,3 @@
+""" Imports for the DGI. """
 from .models import DGI
-from .utils import DGILoss
+from .utils.loss import DGILoss
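The corrected import pulls DGILoss directly from its defining module instead of going through the package re-export. A sketch of the two forms (both resolve here, since dgi/utils/__init__.py re-exports the loss, as shown further down):

# Sketch only: both imports work inside the dgi package, but the explicit
# submodule path no longer depends on the re-export in dgi/utils/__init__.py.
from .utils import DGILoss        # old form: relies on the package re-export
from .utils.loss import DGILoss   # new form: imports the defining module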
63 changes: 33 additions & 30 deletions l2gv2/embedding/dgi/execute.py
@@ -1,34 +1,35 @@
+""" TODO: module docstring for dgi/execute.py. """
+import argparse
 import torch
-import torch.nn as nn
+from torch import nn
 import torch_geometric as tg
-import argparse


 from .models import DGI, LogReg
 from .utils.loss import DGILoss


 parser = argparse.ArgumentParser(description="DGI test script")
-parser.add_argument('--datapath', default='/tmp/cora')
+parser.add_argument("--datapath", default="/tmp/cora")
 args = parser.parse_args()

-dataset = 'cora'
-device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+DATASET = "Cora"
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

 loss_fun = DGILoss()

 # training params
-batch_size = 1
-nb_epochs = 10000
-patience = 20
-lr = 0.001
-l2_coef = 0.0
-drop_prob = 0.0
-hid_units = 512
-sparse = True
-nonlinearity = 'prelu'  # special name to separate parameters
-
-data = tg.datasets.Planetoid(name='Cora', root=args.datapath)[0]
+BATCH_SIZE = 1
+NB_EPOCHS = 10000
+PATIENCE = 20
+LR = 0.001
+L2_COEF = 0.0
+DROP_PROB = 0.0
+HID_UNITS = 512
+SPARSE = True
+NONLINEARITY = "prelu"  # special name to separate parameters
+
+data = tg.datasets.Planetoid(name=DATASET, root=args.datapath)[0]
 data = data.to(device)
 r_sum = data.x.sum(dim=1)
 r_sum[r_sum == 0] = 1.0  # avoid division by zero
@@ -56,40 +57,43 @@
 # idx_val = torch.LongTensor(idx_val)
 # idx_test = torch.LongTensor(idx_test)

-model = DGI(ft_size, hid_units, nonlinearity)
+model = DGI(ft_size, HID_UNITS, NONLINEARITY)
 model = model.to(device)

-optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)
+optimiser = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=L2_COEF)

+STATE_DICT = "best_dgi.pkl"
+
 xent = nn.CrossEntropyLoss()
+# pylint: disable=invalid-name
 cnt_wait = 0
 best = 1e9
 best_t = 0

-for epoch in range(nb_epochs):
+for epoch in range(NB_EPOCHS):
     model.train()
     optimiser.zero_grad()
     loss = loss_fun(model, data)

-    print('Loss:', loss)
+    print("Loss:", loss)

     if loss < best:
         best = loss
         best_t = epoch
         cnt_wait = 0
-        torch.save(model.state_dict(), 'best_dgi.pkl')
+        torch.save(model.state_dict(), STATE_DICT)
     else:
         cnt_wait += 1

-    if cnt_wait == patience:
-        print('Early stopping!')
+    if cnt_wait == PATIENCE:
+        print("Early stopping!")
         break

     loss.backward()
     optimiser.step()
+# pylint: enable=invalid-name

-print('Loading {}th epoch'.format(best_t))
-model.load_state_dict(torch.load('best_dgi.pkl'))
+print(f"Loading {best_t}th epoch")
+model.load_state_dict(torch.load(STATE_DICT))

 embeds = model.embed(data)
 train_embs = embeds[data.train_mask]
@@ -104,13 +108,12 @@
 accs = []

 for _ in range(50):
-    log = LogReg(hid_units, nb_classes).to(device)
+    log = LogReg(HID_UNITS, nb_classes).to(device)

     opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)

-    pat_steps = 0
     best_acc = torch.zeros(1, device=device)

     for _ in range(100):
         log.train()
         opt.zero_grad()
@@ -128,7 +131,7 @@
     print(acc)
     tot += acc

-print('Average accuracy:', tot / 50)
+print("Average accuracy:", tot / 50)

 accs = torch.stack(accs)
 print(accs.mean())
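The renames in this file follow pylint's invalid-name convention: module-level bindings are treated as constants and expected in UPPER_CASE, while genuinely mutable training state keeps its lower-case names under disable/enable pragmas. A minimal sketch of the pattern, using the names from the diff:

NB_EPOCHS = 10000   # true constant: UPPER_CASE satisfies invalid-name
PATIENCE = 20

# pylint: disable=invalid-name
cnt_wait = 0        # mutable loop state stays lower case under a pragma
best = 1e9
best_t = 0
# pylint: enable=invalid-name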
1 change: 1 addition & 0 deletions l2gv2/embedding/dgi/layers/__init__.py
@@ -1,3 +1,4 @@
+""" TODO: module docstring for dgi/layers/__init__.py. """
 from .gcn import GCN
 from .readout import AvgReadout
 from .discriminator import Discriminator
9 changes: 6 additions & 3 deletions l2gv2/embedding/dgi/layers/discriminator.py
@@ -1,21 +1,25 @@
+""" TODO: module docstring for dgi/layers/discriminator.py. """
 import torch
-import torch.nn as nn
+from torch import nn


 class Discriminator(nn.Module):
+    """ TODO: class docstring for Discriminator. """
     def __init__(self, n_h):
-        super(Discriminator, self).__init__()
+        super().__init__()
         self.f_k = nn.Bilinear(n_h, n_h, 1)
         self.reset_parameters()

     def reset_parameters(self):
+        """ TODO: method docstring for Discriminator.reset_parameters. """
         for m in self.modules():
             if isinstance(m, nn.Bilinear):
                 torch.nn.init.xavier_uniform_(m.weight.data)
                 if m.bias is not None:
                     m.bias.data.fill_(0.0)

     def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
+        """ TODO: method docstring for Discriminator.forward. """
         c_x = torch.unsqueeze(c, 0)
         c_x = c_x.expand_as(h_pl)

@@ -30,4 +34,3 @@ def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
         logits = torch.cat((sc_1, sc_2), 0)

         return logits
-
15 changes: 9 additions & 6 deletions l2gv2/embedding/dgi/layers/gcn.py
@@ -1,24 +1,27 @@
-import torch.nn as nn
+""" TODO: module docstring for dgi/layers/gcn.py. """
+from torch import nn
 import torch_geometric.nn as tg_nn


 class GCN(nn.Module):
+    """ TODO: class docstring for GCN. """
     def __init__(self, in_ft, out_ft, act, bias=True):
-        super(GCN, self).__init__()
+        super().__init__()
         self.conv = tg_nn.GCNConv(in_channels=in_ft, out_channels=out_ft, bias=bias)
-        self.act = nn.PReLU() if act == 'prelu' else act
+        self.act = nn.PReLU() if act == "prelu" else act
         self.reset_parameters()

     def reset_parameters(self):
+        """ TODO: method docstring for GCN.reset_parameters. """
         self.conv.reset_parameters()
-        if hasattr(self.act, 'reset_parameters'):
+        if hasattr(self.act, "reset_parameters"):
             self.act.reset_parameters()
         elif isinstance(self.act, nn.PReLU):
             self.act.weight.data.fill_(0.25)

     # Shape of seq: (batch, nodes, features)
     def forward(self, seq, adj):
+        """ TODO: method docstring for GCN.forward. """
         out = self.conv(seq, adj)

-        return self.act(out)
-
+        return self.act(out)
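Note that the layer wraps torch_geometric's GCNConv, so despite the "(batch, nodes, features)" comment carried over from the original DGI code, seq is an unbatched (num_nodes, in_ft) feature matrix and adj is an edge index. A minimal sketch with illustrative sizes:

import torch
from l2gv2.embedding.dgi.layers import GCN

x = torch.rand(4, 8)                               # 4 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 directed edges
layer = GCN(in_ft=8, out_ft=16, act="prelu")
out = layer(x, edge_index)                         # shape: (4, 16)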
14 changes: 9 additions & 5 deletions l2gv2/embedding/dgi/layers/readout.py
@@ -1,17 +1,21 @@
+"""TODO: module docstring for dgi/layers/readout.py."""
+
 import torch
-import torch.nn as nn
+from torch import nn


 # Applies an average on seq, of shape (batch, nodes, features)
 # While taking into account the masking of msk
 class AvgReadout(nn.Module):
+    """TODO: class docstring for AvgReadout."""
+
     def __init__(self):
-        super(AvgReadout, self).__init__()
+        super().__init__()

     def forward(self, seq, msk):
+        """TODO: method docstring for AvgReadout.forward."""
         if msk is None:
             return torch.mean(seq, 0)
-        else:
-            msk = torch.unsqueeze(msk, -1)
-            return torch.sum(seq * msk, 0) / torch.sum(msk)
+
+        msk = torch.unsqueeze(msk, -1)
+        return torch.sum(seq * msk, 0) / torch.sum(msk)
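The rewritten forward replaces the else branch with an early return; behaviour is unchanged: with msk=None the readout is a plain mean over the node dimension, otherwise a mask-weighted mean. A small sketch with illustrative shapes:

import torch
from l2gv2.embedding.dgi.layers import AvgReadout

read = AvgReadout()
seq = torch.rand(5, 16)                        # 5 node embeddings
msk = torch.tensor([1.0, 1.0, 0.0, 0.0, 0.0])  # keep only the first two nodes
summary = read(seq, msk)                       # mean of the first two rows, shape (16,)
full = read(seq, None)                         # plain mean over all rows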
1 change: 1 addition & 0 deletions l2gv2/embedding/dgi/models/__init__.py
@@ -1,2 +1,3 @@
+""" TODO: module docstring for dgi/models/__init__.py. """
 from .dgi import DGI
 from .logreg import LogReg
16 changes: 10 additions & 6 deletions l2gv2/embedding/dgi/models/dgi.py
@@ -1,10 +1,12 @@
-import torch.nn as nn
+""" TODO: module docstring for dgi/models/dgi.py. """
+from torch import nn
 from ..layers import GCN, AvgReadout, Discriminator


 class DGI(nn.Module):
-    def __init__(self, n_in, n_h, activation='prelu'):
-        super(DGI, self).__init__()
+    """ TODO: class docstring for DGI. """
+    def __init__(self, n_in, n_h, activation="prelu"):
+        super().__init__()
         self.gcn = GCN(n_in, n_h, activation)
         self.read = AvgReadout()
@@ -13,11 +15,13 @@ def __init__(self, n_in, n_h, activation='prelu'):
         self.disc = Discriminator(n_h)

     def reset_parameters(self):
+        """ TODO: method docstring for DGI.reset_parameters. """
         for m in self.children():
-            if hasattr(m, 'reset_parameters'):
+            if hasattr(m, "reset_parameters"):
                 m.reset_parameters()

     def forward(self, seq1, seq2, adj, msk, samp_bias1, samp_bias2):
+        """ TODO: method docstring for DGI.forward. """
         h_1 = self.gcn(seq1, adj)

         c = self.read(h_1, msk)
@@ -30,8 +34,8 @@ def forward(self, seq1, seq2, adj, msk, samp_bias1, samp_bias2):
         return ret

     # Detach the return variables
-    def embed(self, data, msk=None):
+    def embed(self, data):
+        """ TODO: method docstring for DGI.embed. """
         h_1 = self.gcn(data.x, data.edge_index)

         return h_1.detach()
-
11 changes: 7 additions & 4 deletions l2gv2/embedding/dgi/models/logreg.py
@@ -1,22 +1,25 @@
+""" TODO: module docstring for dgi/models/logreg.py. """
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn


 class LogReg(nn.Module):
+    """ TODO: class docstring for LogReg. """
     def __init__(self, ft_in, nb_classes):
-        super(LogReg, self).__init__()
+        super().__init__()
         self.fc = nn.Linear(ft_in, nb_classes)

         for m in self.modules():
             self.weights_init(m)

     def weights_init(self, m):
+        """ TODO: method docstring for LogReg.weights_init. """
         if isinstance(m, nn.Linear):
             torch.nn.init.xavier_uniform_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.fill_(0.0)

     def forward(self, seq):
+        """ TODO: method docstring for LogReg.forward """
         ret = self.fc(seq)
         return ret
-
1 change: 1 addition & 0 deletions l2gv2/embedding/dgi/utils/__init__.py
@@ -1 +1,2 @@
+""" TODO: module docstring for dgi/utils/__init__.py. """
 from .loss import DGILoss
3 changes: 3 additions & 0 deletions l2gv2/embedding/dgi/utils/loss.py
@@ -1,13 +1,16 @@
+""" TODO: module docstring for dgi/utils/loss.py. """
 import torch_geometric as tg
 import torch


 class DGILoss(torch.nn.Module):
+    """ TODO: class docstring for DGILoss. """
     def __init__(self):
         super().__init__()
         self.loss_fun = torch.nn.BCEWithLogitsLoss()

     def forward(self, model, data: tg.data.Data):
+        """ TODO: method docstring for DGILoss.forward. """
         device = data.edge_index.device
         nb_nodes = data.num_nodes
         idx = torch.randperm(nb_nodes, device=device)
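The visible part of forward draws a random permutation of the node indices; in DGI this is the corruption function, producing negative samples by shuffling node features while the graph structure stays fixed. A sketch of that step in isolation (the remainder of forward is collapsed above):

import torch

x = torch.rand(5, 16)            # original node features
idx = torch.randperm(x.size(0))  # random node permutation
x_corrupted = x[idx]             # shuffled features: the negative sample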