diff --git a/graphgym/custom_graphgym/act/example.py b/graphgym/custom_graphgym/act/example.py index 02a83b4ec4e3..b6b79bcccd1f 100644 --- a/graphgym/custom_graphgym/act/example.py +++ b/graphgym/custom_graphgym/act/example.py @@ -18,6 +18,4 @@ def forward(self, x): register_act('swish', SWISH(inplace=cfg.mem.inplace)) - -register_act('lrelu_03', - nn.LeakyReLU(negative_slope=0.3, inplace=cfg.mem.inplace)) +register_act('lrelu_03', nn.LeakyReLU(0.3, inplace=cfg.mem.inplace)) diff --git a/graphgym/custom_graphgym/config/example.py b/graphgym/custom_graphgym/config/example.py index 8f27399137cf..49374e16d2a7 100644 --- a/graphgym/custom_graphgym/config/example.py +++ b/graphgym/custom_graphgym/config/example.py @@ -3,6 +3,7 @@ from torch_geometric.graphgym.register import register_config +@register_config('example') def set_cfg_example(cfg): r''' This function sets the default config value for customized options @@ -21,6 +22,3 @@ def set_cfg_example(cfg): # then argument can be specified within the group cfg.example_group.example_arg = 'example' - - -register_config('example', set_cfg_example) diff --git a/graphgym/custom_graphgym/encoder/example.py b/graphgym/custom_graphgym/encoder/example.py index 4cdc5e3cc4d9..cda03c09756a 100644 --- a/graphgym/custom_graphgym/encoder/example.py +++ b/graphgym/custom_graphgym/encoder/example.py @@ -5,6 +5,7 @@ from ogb.utils.features import get_bond_feature_dims +@register_node_encoder('example') class ExampleNodeEncoder(torch.nn.Module): """ Provides an encoder for integer node features @@ -25,9 +26,7 @@ def forward(self, batch): return batch -register_node_encoder('example', ExampleNodeEncoder) - - +@register_edge_encoder('example') class ExampleEdgeEncoder(torch.nn.Module): def __init__(self, emb_dim): super().__init__() @@ -48,6 +47,3 @@ def forward(self, batch): batch.edge_attr = bond_embedding return batch - - -register_edge_encoder('example', ExampleEdgeEncoder) diff --git a/graphgym/custom_graphgym/head/example.py 
b/graphgym/custom_graphgym/head/example.py index 53452ac37a8e..25b3c2644388 100644 --- a/graphgym/custom_graphgym/head/example.py +++ b/graphgym/custom_graphgym/head/example.py @@ -3,6 +3,7 @@ from torch_geometric.graphgym.register import register_head +@register_head('head') class ExampleNodeHead(nn.Module): '''Head of GNN, node prediction''' def __init__(self, dim_in, dim_out): @@ -20,6 +21,3 @@ def forward(self, batch): batch = self.layer_post_mp(batch) pred, label = self._apply_index(batch) return pred, label - - -register_head('example', ExampleNodeHead) diff --git a/graphgym/custom_graphgym/layer/example.py b/graphgym/custom_graphgym/layer/example.py index 8464b771584a..231d0d7de5a8 100644 --- a/graphgym/custom_graphgym/layer/example.py +++ b/graphgym/custom_graphgym/layer/example.py @@ -13,6 +13,7 @@ # Example 1: Directly define a GraphGym format Conv # take 'batch' as input and 'batch' as output +@register_layer('exampleconv1') class ExampleConv1(MessagePassing): r"""Example GNN layer @@ -54,10 +55,6 @@ def update(self, aggr_out): return aggr_out -# Remember to register your layer! -register_layer('exampleconv1', ExampleConv1) - - # Example 2: First define a PyG format Conv layer # Then wrap it to become GraphGym format class ExampleConv2Layer(MessagePassing): @@ -98,6 +95,7 @@ def update(self, aggr_out): return aggr_out +@register_layer('exampleconv2') class ExampleConv2(nn.Module): def __init__(self, dim_in, dim_out, bias=False, **kwargs): super().__init__() @@ -106,7 +104,3 @@ def __init__(self, dim_in, dim_out, bias=False, **kwargs): def forward(self, batch): batch.x = self.model(batch.x, batch.edge_index) return batch - - -# Remember to register your layer! 
-register_layer('exampleconv2', ExampleConv2) diff --git a/graphgym/custom_graphgym/loader/example.py b/graphgym/custom_graphgym/loader/example.py index dc88a297edbc..60acc70591b1 100644 --- a/graphgym/custom_graphgym/loader/example.py +++ b/graphgym/custom_graphgym/loader/example.py @@ -3,12 +3,10 @@ from torch_geometric.graphgym.register import register_loader +@register_loader('example') def load_dataset_example(format, name, dataset_dir): dataset_dir = f'{dataset_dir}/{name}' if format == 'PyG': if name == 'QM7b': dataset_raw = QM7b(dataset_dir) return dataset_raw - - -register_loader('example', load_dataset_example) diff --git a/graphgym/custom_graphgym/loss/example.py b/graphgym/custom_graphgym/loss/example.py index 0f5c69e9fad5..4e897a7ae8e3 100644 --- a/graphgym/custom_graphgym/loss/example.py +++ b/graphgym/custom_graphgym/loss/example.py @@ -1,15 +1,12 @@ -import torch.nn as nn - -from torch_geometric.graphgym.register import register_loss +import torch from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.register import register_loss +@register_loss('smoothl1') def loss_example(pred, true): if cfg.model.loss_fun == 'smoothl1': - l1_loss = nn.SmoothL1Loss() + l1_loss = torch.nn.SmoothL1Loss() loss = l1_loss(pred, true) return loss, pred - - -register_loss('smoothl1', loss_example) diff --git a/graphgym/custom_graphgym/network/example.py b/graphgym/custom_graphgym/network/example.py index 1bc8b44d2592..0d69263865aa 100644 --- a/graphgym/custom_graphgym/network/example.py +++ b/graphgym/custom_graphgym/network/example.py @@ -9,6 +9,7 @@ from torch_geometric.graphgym.register import register_network +@register_network('example') class ExampleGNN(torch.nn.Module): def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'): super().__init__() @@ -44,6 +45,3 @@ def forward(self, batch): batch = self.post_mp(batch) return batch - - -register_network('example', ExampleGNN) diff --git 
a/graphgym/custom_graphgym/optimizer/example.py b/graphgym/custom_graphgym/optimizer/example.py index 5886c4e46d60..99be12917d38 100644 --- a/graphgym/custom_graphgym/optimizer/example.py +++ b/graphgym/custom_graphgym/optimizer/example.py @@ -1,11 +1,12 @@ import torch.optim as optim -from torch_geometric.graphgym.register import register_optimizer, \ - register_scheduler +from torch_geometric.graphgym.register import (register_optimizer, + register_scheduler) from torch_geometric.graphgym.optimizer import OptimizerConfig, SchedulerConfig +@register_optimizer('adagrad') def optimizer_example(params, optimizer_config: OptimizerConfig): if optimizer_config.optimizer == 'adagrad': optimizer = optim.Adagrad(params, lr=optimizer_config.base_lr, @@ -13,13 +14,8 @@ def optimizer_example(params, optimizer_config: OptimizerConfig): return optimizer -register_optimizer('adagrad', optimizer_example) - - +@register_scheduler('reduce') def scheduler_example(optimizer, scheduler_config: SchedulerConfig): if scheduler_config.scheduler == 'reduce': scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) return scheduler - - -register_scheduler('reduce', scheduler_example) diff --git a/graphgym/custom_graphgym/pooling/example.py b/graphgym/custom_graphgym/pooling/example.py index 0f6e899bcb70..bc09e7c34089 100644 --- a/graphgym/custom_graphgym/pooling/example.py +++ b/graphgym/custom_graphgym/pooling/example.py @@ -2,9 +2,7 @@ from torch_geometric.graphgym.register import register_pooling +@register_pooling('example') def global_example_pool(x, batch, size=None): size = batch.max().item() + 1 if size is None else size return scatter(x, batch, dim=0, dim_size=size, reduce='add') - - -register_pooling('example', global_example_pool) diff --git a/graphgym/custom_graphgym/stage/example.py b/graphgym/custom_graphgym/stage/example.py index e0c182549e46..7bf3e264aa7f 100644 --- a/graphgym/custom_graphgym/stage/example.py +++ b/graphgym/custom_graphgym/stage/example.py @@ -10,6 +10,7 
@@ def GNNLayer(dim_in, dim_out, has_act=True): return GeneralLayer(cfg.gnn.layer_type, dim_in, dim_out, has_act) +@register_stage('example') class GNNStackStage(nn.Module): '''Simple Stage that stack GNN layers''' def __init__(self, dim_in, dim_out, num_layers): @@ -26,6 +27,3 @@ def forward(self, batch): if cfg.gnn.l2norm: batch.x = F.normalize(batch.x, p=2, dim=-1) return batch - - -register_stage('example', GNNStackStage) diff --git a/graphgym/custom_graphgym/train/example.py b/graphgym/custom_graphgym/train/example.py index 63bee71e98e4..e62c195b8eb4 100644 --- a/graphgym/custom_graphgym/train/example.py +++ b/graphgym/custom_graphgym/train/example.py @@ -44,6 +44,7 @@ def eval_epoch(logger, loader, model): time_start = time.time() +@register_train('example') def train_example(loggers, loaders, model, optimizer, scheduler): start_epoch = 0 if cfg.train.auto_resume: @@ -69,6 +70,3 @@ def train_example(loggers, loaders, model, optimizer, scheduler): clean_ckpt() logging.info('Task done, results saved in %s', cfg.run_dir) - - -register_train('example', train_example) diff --git a/torch_geometric/graphgym/contrib/train/benchmark.py b/torch_geometric/graphgym/contrib/train/benchmark.py index fc0350bb1671..814a038369fe 100644 --- a/torch_geometric/graphgym/contrib/train/benchmark.py +++ b/torch_geometric/graphgym/contrib/train/benchmark.py @@ -61,6 +61,7 @@ def eval_epoch(logger, loader, model, split='val'): time_start = time.time() +@register_train('bench') def train(loggers, loaders, model, optimizer, scheduler): start_epoch = 0 if cfg.train.auto_resume: @@ -88,6 +89,3 @@ def train(loggers, loaders, model, optimizer, scheduler): clean_ckpt() logging.info('Task done, results saved in {}'.format(cfg.run_dir)) - - -register_train('bench', train) diff --git a/torch_geometric/graphgym/model_builder.py b/torch_geometric/graphgym/model_builder.py index a2703d0a528d..00e58184ca62 100644 --- a/torch_geometric/graphgym/model_builder.py +++ 
b/torch_geometric/graphgym/model_builder.py @@ -2,13 +2,9 @@ from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.models.gnn import GNN +from torch_geometric.graphgym.register import register_network, network_dict -import torch_geometric.graphgym.register as register - -network_dict = { - 'gnn': GNN, -} -register.network_dict = {**register.network_dict, **network_dict} +register_network('gnn', GNN) def create_model(to_device=True, dim_in=None, dim_out=None): @@ -19,8 +15,6 @@ def create_model(to_device=True, dim_in=None, dim_out=None): to_device (string): The devide that the model will be transferred to dim_in (int, optional): Input dimension to the model dim_out (int, optional): Output dimension to the model - - """ dim_in = cfg.share.dim_in if dim_in is None else dim_in dim_out = cfg.share.dim_out if dim_out is None else dim_out @@ -28,8 +22,7 @@ def create_model(to_device=True, dim_in=None, dim_out=None): if 'classification' in cfg.dataset.task_type and dim_out == 2: dim_out = 1 - model = register.network_dict[cfg.model.type](dim_in=dim_in, - dim_out=dim_out) + model = network_dict[cfg.model.type](dim_in=dim_in, dim_out=dim_out) if to_device: model.to(torch.device(cfg.device)) return model diff --git a/torch_geometric/graphgym/models/act.py b/torch_geometric/graphgym/models/act.py index 239089f18899..3ba62fc521a8 100644 --- a/torch_geometric/graphgym/models/act.py +++ b/torch_geometric/graphgym/models/act.py @@ -1,15 +1,12 @@ import torch.nn as nn -from torch_geometric.graphgym.config import cfg -import torch_geometric.graphgym.register as register -act_dict = { - 'relu': nn.ReLU(inplace=cfg.mem.inplace), - 'selu': nn.SELU(inplace=cfg.mem.inplace), - 'prelu': nn.PReLU(), - 'elu': nn.ELU(inplace=cfg.mem.inplace), - 'lrelu_01': nn.LeakyReLU(negative_slope=0.1, inplace=cfg.mem.inplace), - 'lrelu_025': nn.LeakyReLU(negative_slope=0.25, inplace=cfg.mem.inplace), - 'lrelu_05': nn.LeakyReLU(negative_slope=0.5, inplace=cfg.mem.inplace), -} +from 
torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.register import register_act -register.act_dict = {**register.act_dict, **act_dict} +register_act('relu', nn.ReLU(inplace=cfg.mem.inplace)) +register_act('selu', nn.SELU(inplace=cfg.mem.inplace)) +register_act('prelu', nn.PReLU()) +register_act('elu', nn.ELU(inplace=cfg.mem.inplace)) +register_act('lrelu_01', nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)) +register_act('lrelu_025', nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)) +register_act('lrelu_05', nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)) diff --git a/torch_geometric/graphgym/models/encoder.py b/torch_geometric/graphgym/models/encoder.py index 6667a6e6ee01..e3ea5a4d0995 100644 --- a/torch_geometric/graphgym/models/encoder.py +++ b/torch_geometric/graphgym/models/encoder.py @@ -1,8 +1,10 @@ import torch -import torch_geometric.graphgym.register as register +from torch_geometric.graphgym.register import (register_node_encoder, + register_edge_encoder) +@register_node_encoder('Integer') class IntegerFeatureEncoder(torch.nn.Module): """ Provides an encoder for integer node features. @@ -26,6 +28,7 @@ def forward(self, batch): return batch +@register_node_encoder('Atom') class AtomEncoder(torch.nn.Module): """ The atom Encoder used in OGB molecule dataset. @@ -56,6 +59,7 @@ def forward(self, batch): return batch +@register_edge_encoder('Bond') class BondEncoder(torch.nn.Module): """ The bond Encoder used in OGB molecule dataset. 
@@ -84,18 +88,3 @@ def forward(self, batch): batch.edge_attr = bond_embedding return batch - - -node_encoder_dict = {'Integer': IntegerFeatureEncoder, 'Atom': AtomEncoder} - -register.node_encoder_dict = { - **register.node_encoder_dict, - **node_encoder_dict -} - -edge_encoder_dict = {'Bond': BondEncoder} - -register.edge_encoder_dict = { - **register.edge_encoder_dict, - **edge_encoder_dict -} diff --git a/torch_geometric/graphgym/models/gnn.py b/torch_geometric/graphgym/models/gnn.py index 8b41800bfc05..102c603d2582 100644 --- a/torch_geometric/graphgym/models/gnn.py +++ b/torch_geometric/graphgym/models/gnn.py @@ -3,15 +3,13 @@ import torch.nn.functional as F from torch_geometric.graphgym.config import cfg -import torch_geometric.graphgym.models.head from torch_geometric.graphgym.models.layer import (new_layer_config, GeneralLayer, GeneralMultiLayer, BatchNorm1dNode) from torch_geometric.graphgym.init import init_weights - -import torch_geometric.graphgym.models.encoder # noqa, register module import torch_geometric.graphgym.register as register +from torch_geometric.graphgym.register import register_stage def GNNLayer(dim_in, dim_out, has_act=True): @@ -46,6 +44,9 @@ def GNNPreMP(dim_in, dim_out, num_layers): has_act=False, has_bias=False, cfg=cfg)) +@register_stage('stack') +@register_stage('skipsum') +@register_stage('skipconcat') class GNNStackStage(nn.Module): """ Simple Stage that stack GNN layers @@ -81,15 +82,6 @@ def forward(self, batch): return batch -stage_dict = { - 'stack': GNNStackStage, - 'skipsum': GNNStackStage, - 'skipconcat': GNNStackStage, -} - -register.stage_dict = {**register.stage_dict, **stage_dict} - - class FeatureEncoder(nn.Module): """ Encoding node and edge features diff --git a/torch_geometric/graphgym/models/head.py b/torch_geometric/graphgym/models/head.py index 3d61ce0a7d48..52cca613685d 100644 --- a/torch_geometric/graphgym/models/head.py +++ b/torch_geometric/graphgym/models/head.py @@ -7,11 +7,11 @@ from 
torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.models.layer import new_layer_config, MLP -import torch_geometric.graphgym.models.pooling # noqa, register module - import torch_geometric.graphgym.register as register +from torch_geometric.graphgym.register import register_head +@register_head('node') class GNNNodeHead(nn.Module): """ GNN prediction head for node prediction tasks. @@ -37,6 +37,8 @@ def forward(self, batch): return pred, label +@register_head('edge') +@register_head('link_pred') class GNNEdgeHead(nn.Module): """ GNN prediction head for edge/link prediction tasks. @@ -87,6 +89,7 @@ def forward(self, batch): return pred, label +@register_head('graph') class GNNGraphHead(nn.Module): """ GNN prediction head for graph prediction tasks. @@ -113,14 +116,3 @@ def forward(self, batch): batch.graph_feature = graph_emb pred, label = self._apply_index(batch) return pred, label - - -# Head models for external interface -head_dict = { - 'node': GNNNodeHead, - 'edge': GNNEdgeHead, - 'link_pred': GNNEdgeHead, - 'graph': GNNGraphHead -} - -register.head_dict = {**register.head_dict, **head_dict} diff --git a/torch_geometric/graphgym/models/layer.py b/torch_geometric/graphgym/models/layer.py index c62fd8d66140..3a801ba0e9dc 100644 --- a/torch_geometric/graphgym/models/layer.py +++ b/torch_geometric/graphgym/models/layer.py @@ -12,6 +12,7 @@ GeneralConvLayer, GeneralEdgeConvLayer) import torch_geometric.graphgym.register as register +from torch_geometric.graphgym.register import register_layer @dataclass @@ -150,6 +151,7 @@ def forward(self, batch): # ---------- Core basic layers. Input: batch; Output: batch ----------------- # +@register_layer('linear') class Linear(nn.Module): """ Basic Linear layer. @@ -207,6 +209,7 @@ def forward(self, batch): return batch +@register_layer('mlp') class MLP(nn.Module): """ Basic MLP model. 
@@ -247,6 +250,7 @@ def forward(self, batch): return batch +@register_layer('gcnconv') class GCNConv(nn.Module): """ Graph Convolutional Network (GCN) layer @@ -261,6 +265,7 @@ def forward(self, batch): return batch +@register_layer('sageconv') class SAGEConv(nn.Module): """ GraphSAGE Conv layer @@ -275,6 +280,7 @@ def forward(self, batch): return batch +@register_layer('gatconv') class GATConv(nn.Module): """ Graph Attention Network (GAT) layer @@ -289,6 +295,7 @@ def forward(self, batch): return batch +@register_layer('ginconv') class GINConv(nn.Module): """ Graph Isomorphism Network (GIN) layer @@ -305,6 +312,7 @@ def forward(self, batch): return batch +@register_layer('splineconv') class SplineConv(nn.Module): """ SplineCNN layer @@ -321,6 +329,7 @@ def forward(self, batch): return batch +@register_layer('generalconv') class GeneralConv(nn.Module): """A general GNN layer""" def __init__(self, layer_config: LayerConfig, **kwargs): @@ -334,6 +343,7 @@ def forward(self, batch): return batch +@register_layer('generaledgeconv') class GeneralEdgeConv(nn.Module): """A general GNN layer that supports edge features as well""" def __init__(self, layer_config: LayerConfig, **kwargs): @@ -349,6 +359,7 @@ def forward(self, batch): return batch +@register_layer('generalsampleedgeconv') class GeneralSampleEdgeConv(nn.Module): """A general GNN layer that supports edge features and edge sampling""" def __init__(self, layer_config: LayerConfig, **kwargs): @@ -365,20 +376,3 @@ def forward(self, batch): edge_feature = batch.edge_attr[edge_mask, :] batch.x = self.model(batch.x, edge_index, edge_feature=edge_feature) return batch - - -layer_dict = { - 'linear': Linear, - 'mlp': MLP, - 'gcnconv': GCNConv, - 'sageconv': SAGEConv, - 'gatconv': GATConv, - 'splineconv': SplineConv, - 'ginconv': GINConv, - 'generalconv': GeneralConv, - 'generaledgeconv': GeneralEdgeConv, - 'generalsampleedgeconv': GeneralSampleEdgeConv, -} - -# register additional convs -register.layer_dict = 
{**register.layer_dict, **layer_dict} diff --git a/torch_geometric/graphgym/models/pooling.py b/torch_geometric/graphgym/models/pooling.py index 11dd93f934ce..f9dc256fa8fc 100644 --- a/torch_geometric/graphgym/models/pooling.py +++ b/torch_geometric/graphgym/models/pooling.py @@ -1,69 +1,7 @@ -from torch_scatter import scatter +from torch_geometric.nn import (global_add_pool, global_mean_pool, + global_max_pool) +from torch_geometric.graphgym.register import register_pooling -import torch_geometric.graphgym.register as register - - -def global_add_pool(x, batch, size=None): - """ - Globally pool node embeddings into graph embeddings, via elementwise sum. - Pooling function takes in node embedding [num_nodes x emb_dim] and - batch (indices) and outputs graph embedding [num_graphs x emb_dim]. - - Args: - x (torch.tensor): Input node embeddings - batch (torch.tensor): Batch tensor that indicates which node - belongs to which graph - size (optional): Total number of graphs. Can be auto-inferred. - - Returns: Pooled graph embeddings - - """ - size = batch.max().item() + 1 if size is None else size - return scatter(x, batch, dim=0, dim_size=size, reduce='add') - - -def global_mean_pool(x, batch, size=None): - """ - Globally pool node embeddings into graph embeddings, via elementwise mean. - Pooling function takes in node embedding [num_nodes x emb_dim] and - batch (indices) and outputs graph embedding [num_graphs x emb_dim]. - - Args: - x (torch.tensor): Input node embeddings - batch (torch.tensor): Batch tensor that indicates which node - belongs to which graph - size (optional): Total number of graphs. Can be auto-inferred. - - Returns: Pooled graph embeddings - - """ - size = batch.max().item() + 1 if size is None else size - return scatter(x, batch, dim=0, dim_size=size, reduce='mean') - - -def global_max_pool(x, batch, size=None): - """ - Globally pool node embeddings into graph embeddings, via elementwise max. 
- Pooling function takes in node embedding [num_nodes x emb_dim] and - batch (indices) and outputs graph embedding [num_graphs x emb_dim]. - - Args: - x (torch.tensor): Input node embeddings - batch (torch.tensor): Batch tensor that indicates which node - belongs to which graph - size (optional): Total number of graphs. Can be auto-inferred. - - Returns: Pooled graph embeddings - - """ - size = batch.max().item() + 1 if size is None else size - return scatter(x, batch, dim=0, dim_size=size, reduce='max') - - -pooling_dict = { - 'add': global_add_pool, - 'mean': global_mean_pool, - 'max': global_max_pool -} - -register.pooling_dict = {**register.pooling_dict, **pooling_dict} +register_pooling('add', global_add_pool) +register_pooling('mean', global_mean_pool) +register_pooling('max', global_max_pool) diff --git a/torch_geometric/graphgym/register.py b/torch_geometric/graphgym/register.py index 170bc06bf8db..9c81d0669d3b 100644 --- a/torch_geometric/graphgym/register.py +++ b/torch_geometric/graphgym/register.py @@ -1,239 +1,111 @@ -def register_base(key, module, module_dict): - """ - Base function for registering a customized module to a module dictionary +from typing import Dict, Any, Union, Callable + +act_dict: Dict[str, Any] = {} +node_encoder_dict: Dict[str, Any] = {} +edge_encoder_dict: Dict[str, Any] = {} +stage_dict: Dict[str, Any] = {} +head_dict: Dict[str, Any] = {} +layer_dict: Dict[str, Any] = {} +pooling_dict: Dict[str, Any] = {} +network_dict: Dict[str, Any] = {} +config_dict: Dict[str, Any] = {} +loader_dict: Dict[str, Any] = {} +optimizer_dict: Dict[str, Any] = {} +scheduler_dict: Dict[str, Any] = {} +loss_dict: Dict[str, Any] = {} +train_dict: Dict[str, Any] = {} + + +def register_base(mapping: Dict[str, Any], key: str, + module: Any = None) -> Union[None, Callable]: + r"""Base function for registering a module in GraphGym. 
Args: - key (string): Name of the module - module: PyTorch module - module_dict (dict): Python dictionary, + mapping (dict): Python dictionary to register the module. hosting all the registered modules - - """ - if key in module_dict: - raise KeyError('Key {} is already pre-defined.'.format(key)) - else: - module_dict[key] = module - - -act_dict = {} - - -def register_act(key, module): - """ - Register a customized activation function. - After registeration, the module can be directly called by GraphGym. - - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, act_dict) - - -node_encoder_dict = {} - - -def register_node_encoder(key, module): - """ - Register a customized node feature encoder. - After registeration, the module can be directly called by GraphGym. - - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, node_encoder_dict) - - -edge_encoder_dict = {} - - -def register_edge_encoder(key, module): - """ - Register a customized edge feature encoder. - After registeration, the module can be directly called by GraphGym. - - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, edge_encoder_dict) - - -stage_dict = {} - - -def register_stage(key, module): - """ - Register a customized GNN stage (consists of multiple layers). - After registeration, the module can be directly called by GraphGym. - - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, stage_dict) - - -head_dict = {} - - -def register_head(key, module): - """ - Register a customized GNN prediction head. - After registeration, the module can be directly called by GraphGym. - - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, head_dict) - - -layer_dict = {} - - -def register_layer(key, module): + key (string): The name of the module. 
+ module (any, optional): The module. If set to :obj:`None`, will return + a decorator to register a module. """ - Register a customized GNN layer. - After registeration, the module can be directly called by GraphGym. + if module is not None: + if key in mapping: + raise KeyError(f"Module with '{key}' already defined") + mapping[key] = module + return - Args: - key (string): Name of the module - module: PyTorch module + # Otherwise, use it as a decorator: + def wrapper(module): + register_base(mapping, key, module) + return module - """ - register_base(key, module, layer_dict) + return wrapper -pooling_dict = {} +def register_act(key: str, module: Any = None): + r"""Registers an activation function in GraphGym.""" + return register_base(act_dict, key, module) -def register_pooling(key, module): - """ - Register a customized GNN pooling layer (for graph classification). - After registeration, the module can be directly called by GraphGym. +def register_node_encoder(key: str, module: Any = None): + r"""Registers a node feature encoder in GraphGym.""" + return register_base(node_encoder_dict, key, module) - Args: - key (string): Name of the module - module: PyTorch module - """ - register_base(key, module, pooling_dict) +def register_edge_encoder(key: str, module: Any = None): + r"""Registers an edge feature encoder in GraphGym.""" + return register_base(edge_encoder_dict, key, module) -network_dict = {} +def register_stage(key: str, module: Any = None): + r"""Registers a customized GNN stage in GraphGym.""" + return register_base(stage_dict, key, module) -def register_network(key, module): - """ - Register a customized GNN model. - After registeration, the module can be directly called by GraphGym. 
+def register_head(key: str, module: Any = None): + r"""Registers a GNN prediction head in GraphGym.""" + return register_base(head_dict, key, module) - Args: - key (string): Name of the module - module: PyTorch module - """ - register_base(key, module, network_dict) +def register_layer(key: str, module: Any = None): + r"""Registers a GNN layer in GraphGym.""" + return register_base(layer_dict, key, module) -config_dict = {} +def register_pooling(key: str, module: Any = None): + r"""Registers a GNN global pooling/readout layer in GraphGym.""" + return register_base(pooling_dict, key, module) -def register_config(key, module): - """ - Register a customized configuration group. - After registeration, the module can be directly called by GraphGym. +def register_network(key: str, module: Any = None): + r"""Registers a GNN model in GraphGym.""" + return register_base(network_dict, key, module) - Args: - key (string): Name of the module - module: PyTorch module - """ - register_base(key, module, config_dict) - - -loader_dict = {} - - -def register_loader(key, module): - """ - Register a customized PyG data loader. - After registeration, the module can be directly called by GraphGym. +def register_config(key: str, module: Any = None): + r"""Registers a configuration group in GraphGym.""" + return register_base(config_dict, key, module) - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, loader_dict) +def register_loader(key: str, module: Any = None): + r"""Registers a data loader in GraphGym.""" + return register_base(loader_dict, key, module) -optimizer_dict = {} +def register_optimizer(key: str, module: Any = None): + r"""Registers an optimizer in GraphGym.""" + return register_base(optimizer_dict, key, module) -def register_optimizer(key, module): - """ - Register a customized optimizer. - After registeration, the module can be directly called by GraphGym. 
- Args: - key (string): Name of the module - module: PyTorch module +def register_scheduler(key: str, module: Any = None): + r"""Registers a learning rate scheduler in GraphGym.""" + return register_base(scheduler_dict, key, module) - """ - register_base(key, module, optimizer_dict) +def register_loss(key: str, module: Any = None): + r"""Registers a loss function in GraphGym.""" + return register_base(loss_dict, key, module) -scheduler_dict = {} - -def register_scheduler(key, module): - """ - Register a customized learning rate scheduler. - After registeration, the module can be directly called by GraphGym. - - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, scheduler_dict) - - -loss_dict = {} - - -def register_loss(key, module): - """ - Register a customized loss function. - After registeration, the module can be directly called by GraphGym. - - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, loss_dict) - - -train_dict = {} - - -def register_train(key, module): - """ - Register a customized training function. - After registeration, the module can be directly called by GraphGym. - - Args: - key (string): Name of the module - module: PyTorch module - - """ - register_base(key, module, train_dict) +def register_train(key: str, module: Any = None): + r"""Registers a training function in GraphGym.""" + return register_base(train_dict, key, module)