Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@register_* decorator in GraphGym #3684

Merged
merged 6 commits into the base branch from the head branch
Dec 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions graphgym/custom_graphgym/act/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,4 @@ def forward(self, x):


register_act('swish', SWISH(inplace=cfg.mem.inplace))

register_act('lrelu_03',
nn.LeakyReLU(negative_slope=0.3, inplace=cfg.mem.inplace))
register_act('lrelu_03', nn.LeakyReLU(inplace=cfg.mem.inplace))
4 changes: 1 addition & 3 deletions graphgym/custom_graphgym/config/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch_geometric.graphgym.register import register_config


@register_config('example')
def set_cfg_example(cfg):
r'''
This function sets the default config value for customized options
Expand All @@ -21,6 +22,3 @@ def set_cfg_example(cfg):

# then argument can be specified within the group
cfg.example_group.example_arg = 'example'


register_config('example', set_cfg_example)
8 changes: 2 additions & 6 deletions graphgym/custom_graphgym/encoder/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ogb.utils.features import get_bond_feature_dims


@register_node_encoder('example')
class ExampleNodeEncoder(torch.nn.Module):
"""
Provides an encoder for integer node features
Expand All @@ -25,9 +26,7 @@ def forward(self, batch):
return batch


register_node_encoder('example', ExampleNodeEncoder)


@register_edge_encoder('example')
class ExampleEdgeEncoder(torch.nn.Module):
def __init__(self, emb_dim):
super().__init__()
Expand All @@ -48,6 +47,3 @@ def forward(self, batch):

batch.edge_attr = bond_embedding
return batch


register_edge_encoder('example', ExampleEdgeEncoder)
4 changes: 1 addition & 3 deletions graphgym/custom_graphgym/head/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch_geometric.graphgym.register import register_head


@register_head('head')
class ExampleNodeHead(nn.Module):
'''Head of GNN, node prediction'''
def __init__(self, dim_in, dim_out):
Expand All @@ -20,6 +21,3 @@ def forward(self, batch):
batch = self.layer_post_mp(batch)
pred, label = self._apply_index(batch)
return pred, label


register_head('example', ExampleNodeHead)
10 changes: 2 additions & 8 deletions graphgym/custom_graphgym/layer/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# Example 1: Directly define a GraphGym format Conv
# take 'batch' as input and 'batch' as output
@register_layer('exampleconv1')
class ExampleConv1(MessagePassing):
r"""Example GNN layer

Expand Down Expand Up @@ -54,10 +55,6 @@ def update(self, aggr_out):
return aggr_out


# Remember to register your layer!
register_layer('exampleconv1', ExampleConv1)


# Example 2: First define a PyG format Conv layer
# Then wrap it to become GraphGym format
class ExampleConv2Layer(MessagePassing):
Expand Down Expand Up @@ -98,6 +95,7 @@ def update(self, aggr_out):
return aggr_out


@register_layer('exampleconv2')
class ExampleConv2(nn.Module):
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super().__init__()
Expand All @@ -106,7 +104,3 @@ def __init__(self, dim_in, dim_out, bias=False, **kwargs):
def forward(self, batch):
batch.x = self.model(batch.x, batch.edge_index)
return batch


# Remember to register your layer!
register_layer('exampleconv2', ExampleConv2)
4 changes: 1 addition & 3 deletions graphgym/custom_graphgym/loader/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from torch_geometric.graphgym.register import register_loader


@register_loader('example')
def load_dataset_example(format, name, dataset_dir):
dataset_dir = f'{dataset_dir}/{name}'
if format == 'PyG':
if name == 'QM7b':
dataset_raw = QM7b(dataset_dir)
return dataset_raw


register_loader('example', load_dataset_example)
11 changes: 4 additions & 7 deletions graphgym/custom_graphgym/loss/example.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import torch.nn as nn

from torch_geometric.graphgym.register import register_loss
import torch

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_loss


@register_loss('smoothl1')
def loss_example(pred, true):
if cfg.model.loss_fun == 'smoothl1':
l1_loss = nn.SmoothL1Loss()
l1_loss = torch.nn.SmoothL1Loss()
loss = l1_loss(pred, true)
return loss, pred


register_loss('smoothl1', loss_example)
4 changes: 1 addition & 3 deletions graphgym/custom_graphgym/network/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch_geometric.graphgym.register import register_network


@register_network('example')
class ExampleGNN(torch.nn.Module):
def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'):
super().__init__()
Expand Down Expand Up @@ -44,6 +45,3 @@ def forward(self, batch):
batch = self.post_mp(batch)

return batch


register_network('example', ExampleGNN)
12 changes: 4 additions & 8 deletions graphgym/custom_graphgym/optimizer/example.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
import torch.optim as optim

from torch_geometric.graphgym.register import register_optimizer, \
register_scheduler
from torch_geometric.graphgym.register import (register_optimizer,
register_scheduler)

from torch_geometric.graphgym.optimizer import OptimizerConfig, SchedulerConfig


@register_optimizer('adagrad')
def optimizer_example(params, optimizer_config: OptimizerConfig):
if optimizer_config.optimizer == 'adagrad':
optimizer = optim.Adagrad(params, lr=optimizer_config.base_lr,
weight_decay=optimizer_config.weight_decay)
return optimizer


register_optimizer('adagrad', optimizer_example)


@register_scheduler('reduce')
def scheduler_example(optimizer, scheduler_config: SchedulerConfig):
if scheduler_config.scheduler == 'reduce':
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return scheduler


register_scheduler('reduce', scheduler_example)
4 changes: 1 addition & 3 deletions graphgym/custom_graphgym/pooling/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from torch_geometric.graphgym.register import register_pooling


@register_pooling('example')
def global_example_pool(x, batch, size=None):
size = batch.max().item() + 1 if size is None else size
return scatter(x, batch, dim=0, dim_size=size, reduce='add')


register_pooling('example', global_example_pool)
4 changes: 1 addition & 3 deletions graphgym/custom_graphgym/stage/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def GNNLayer(dim_in, dim_out, has_act=True):
return GeneralLayer(cfg.gnn.layer_type, dim_in, dim_out, has_act)


@register_stage('example')
class GNNStackStage(nn.Module):
'''Simple Stage that stack GNN layers'''
def __init__(self, dim_in, dim_out, num_layers):
Expand All @@ -26,6 +27,3 @@ def forward(self, batch):
if cfg.gnn.l2norm:
batch.x = F.normalize(batch.x, p=2, dim=-1)
return batch


register_stage('example', GNNStackStage)
4 changes: 1 addition & 3 deletions graphgym/custom_graphgym/train/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def eval_epoch(logger, loader, model):
time_start = time.time()


@register_train('example')
def train_example(loggers, loaders, model, optimizer, scheduler):
start_epoch = 0
if cfg.train.auto_resume:
Expand All @@ -69,6 +70,3 @@ def train_example(loggers, loaders, model, optimizer, scheduler):
clean_ckpt()

logging.info('Task done, results saved in %s', cfg.run_dir)


register_train('example', train_example)
4 changes: 1 addition & 3 deletions torch_geometric/graphgym/contrib/train/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def eval_epoch(logger, loader, model, split='val'):
time_start = time.time()


@register_train('bench')
def train(loggers, loaders, model, optimizer, scheduler):
start_epoch = 0
if cfg.train.auto_resume:
Expand Down Expand Up @@ -88,6 +89,3 @@ def train(loggers, loaders, model, optimizer, scheduler):
clean_ckpt()

logging.info('Task done, results saved in {}'.format(cfg.run_dir))


register_train('bench', train)
13 changes: 3 additions & 10 deletions torch_geometric/graphgym/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNN
from torch_geometric.graphgym.register import register_network, network_dict

import torch_geometric.graphgym.register as register

network_dict = {
'gnn': GNN,
}
register.network_dict = {**register.network_dict, **network_dict}
register_network('gnn', GNN)


def create_model(to_device=True, dim_in=None, dim_out=None):
Expand All @@ -19,17 +15,14 @@ def create_model(to_device=True, dim_in=None, dim_out=None):
to_device (string): The device that the model will be transferred to
dim_in (int, optional): Input dimension to the model
dim_out (int, optional): Output dimension to the model


"""
dim_in = cfg.share.dim_in if dim_in is None else dim_in
dim_out = cfg.share.dim_out if dim_out is None else dim_out
# binary classification, output dim = 1
if 'classification' in cfg.dataset.task_type and dim_out == 2:
dim_out = 1

model = register.network_dict[cfg.model.type](dim_in=dim_in,
dim_out=dim_out)
model = network_dict[cfg.model.type](dim_in=dim_in, dim_out=dim_out)
if to_device:
model.to(torch.device(cfg.device))
return model
21 changes: 9 additions & 12 deletions torch_geometric/graphgym/models/act.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
import torch_geometric.graphgym.register as register

act_dict = {
'relu': nn.ReLU(inplace=cfg.mem.inplace),
'selu': nn.SELU(inplace=cfg.mem.inplace),
'prelu': nn.PReLU(),
'elu': nn.ELU(inplace=cfg.mem.inplace),
'lrelu_01': nn.LeakyReLU(negative_slope=0.1, inplace=cfg.mem.inplace),
'lrelu_025': nn.LeakyReLU(negative_slope=0.25, inplace=cfg.mem.inplace),
'lrelu_05': nn.LeakyReLU(negative_slope=0.5, inplace=cfg.mem.inplace),
}
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act

register.act_dict = {**register.act_dict, **act_dict}
register_act('relu', nn.ReLU(inplace=cfg.mem.inplace))
register_act('selu', nn.SELU(inplace=cfg.mem.inplace))
register_act('prelu', nn.PReLU())
register_act('elu', nn.ELU(inplace=cfg.mem.inplace))
register_act('lrelu_01', nn.LeakyReLU(0.1, inplace=cfg.mem.inplace))
register_act('lrelu_025', nn.LeakyReLU(0.25, inplace=cfg.mem.inplace))
register_act('lrelu_05', nn.LeakyReLU(0.5, inplace=cfg.mem.inplace))
21 changes: 5 additions & 16 deletions torch_geometric/graphgym/models/encoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.register import (register_node_encoder,
register_edge_encoder)


@register_node_encoder('Integer')
class IntegerFeatureEncoder(torch.nn.Module):
"""
Provides an encoder for integer node features.
Expand All @@ -26,6 +28,7 @@ def forward(self, batch):
return batch


@register_node_encoder('Atom')
class AtomEncoder(torch.nn.Module):
"""
The atom Encoder used in OGB molecule dataset.
Expand Down Expand Up @@ -56,6 +59,7 @@ def forward(self, batch):
return batch


@register_edge_encoder('Bond')
class BondEncoder(torch.nn.Module):
"""
The bond Encoder used in OGB molecule dataset.
Expand Down Expand Up @@ -84,18 +88,3 @@ def forward(self, batch):

batch.edge_attr = bond_embedding
return batch


node_encoder_dict = {'Integer': IntegerFeatureEncoder, 'Atom': AtomEncoder}

register.node_encoder_dict = {
**register.node_encoder_dict,
**node_encoder_dict
}

edge_encoder_dict = {'Bond': BondEncoder}

register.edge_encoder_dict = {
**register.edge_encoder_dict,
**edge_encoder_dict
}
16 changes: 4 additions & 12 deletions torch_geometric/graphgym/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
import torch.nn.functional as F

from torch_geometric.graphgym.config import cfg
import torch_geometric.graphgym.models.head
from torch_geometric.graphgym.models.layer import (new_layer_config,
GeneralLayer,
GeneralMultiLayer,
BatchNorm1dNode)
from torch_geometric.graphgym.init import init_weights

import torch_geometric.graphgym.models.encoder # noqa, register module
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.register import register_stage


def GNNLayer(dim_in, dim_out, has_act=True):
Expand Down Expand Up @@ -46,6 +44,9 @@ def GNNPreMP(dim_in, dim_out, num_layers):
has_act=False, has_bias=False, cfg=cfg))


@register_stage('stack')
@register_stage('skipsum')
@register_stage('skipconcat')
class GNNStackStage(nn.Module):
"""
Simple Stage that stack GNN layers
Expand Down Expand Up @@ -81,15 +82,6 @@ def forward(self, batch):
return batch


stage_dict = {
'stack': GNNStackStage,
'skipsum': GNNStackStage,
'skipconcat': GNNStackStage,
}

register.stage_dict = {**register.stage_dict, **stage_dict}


class FeatureEncoder(nn.Module):
"""
Encoding node and edge features
Expand Down
Loading