Skip to content

Commit

Permalink
Merge pull request #129 from ai4co/nar refer to original pull request #…
Browse files Browse the repository at this point in the history
…122

[Feat] Nonautoregressive Methods Resolve Conflicts from #122
  • Loading branch information
cbhua authored Mar 8, 2024
2 parents 416bf38 + ccb2a24 commit fd58215
Show file tree
Hide file tree
Showing 14 changed files with 788 additions and 18 deletions.
6 changes: 6 additions & 0 deletions rl4co/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
AutoregressivePolicy,
GraphAttentionEncoder,
)
from rl4co.models.zoo.common.nonautoregressive import (
NonAutoregressiveDecoder,
NonAutoregressiveEncoder,
NonAutoregressiveModel,
NonAutoregressivePolicy,
)
from rl4co.models.zoo.common.search import SearchBase
from rl4co.models.zoo.eas import EAS, EASEmb, EASLay
from rl4co.models.zoo.ham import (
Expand Down
13 changes: 6 additions & 7 deletions rl4co/models/nn/dec_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ def __init__(
self.select_start_nodes_fn = select_start_nodes_fn

def _step(
self, logp: torch.Tensor, td: TensorDict, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Main decoding operation. This method should be implemented by subclasses."""
self, logp: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]:
raise NotImplementedError("Must be implemented by subclass")

def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase):
Expand Down Expand Up @@ -96,8 +95,9 @@ def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase):

return td, env, self.num_starts

def post_decoder_hook(self, td, env):
"""Post decoding hook. This method is called after the main decoding operation."""
def post_decoder_hook(
self, td: TensorDict, env: RL4COEnvBase
) -> Tuple[torch.Tensor, torch.Tensor, TensorDict, RL4COEnvBase]:
assert (
len(self.logp) > 0
), "No outputs were collected because all environments were done. Check your initial state"
Expand All @@ -106,8 +106,7 @@ def post_decoder_hook(self, td, env):

def step(
self, logp: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]:
"""Main decoding operation. This method calls the :meth:`_step` method and collects the outputs."""
) -> TensorDict:
assert not logp.isinf().all(1).any()

logp, selected_actions, td = self._step(logp, mask, td, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions rl4co/models/nn/env_embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from rl4co.models.nn.env_embeddings.context import env_context_embedding
from rl4co.models.nn.env_embeddings.dynamic import env_dynamic_embedding
from rl4co.models.nn.env_embeddings.edge import env_edge_embedding
from rl4co.models.nn.env_embeddings.init import env_init_embedding
152 changes: 152 additions & 0 deletions rl4co/models/nn/env_embeddings/edge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import torch
import torch.nn as nn

from torch import Tensor

try:
from torch_geometric.data import Batch, Data
except ImportError:
Batch = Data = None

from rl4co.utils.ops import get_distance_matrix, get_full_graph_edge_index, sparsify_graph
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


def env_edge_embedding(env_name: str, config: dict) -> nn.Module:
"""Retrieve the edge embedding module specific to the environment. Edge embeddings are crucial for
transforming the raw edge features into a format suitable for the neural network, especially in
graph neural networks where edge features can significantly impact the model's performance.
Args:
env: Environment or its name.
config: A dictionary of configuration options for the environment.
"""
embedding_registry = {
"tsp": TSPEdgeEmbedding,
"atsp": ATSPEdgeEmbedding,
"cvrp": TSPEdgeEmbedding,
"sdvrp": TSPEdgeEmbedding,
"pctsp": TSPEdgeEmbedding,
"spctsp": TSPEdgeEmbedding,
"op": TSPEdgeEmbedding,
"dpp": TSPEdgeEmbedding,
"mdpp": TSPEdgeEmbedding,
"pdp": TSPEdgeEmbedding,
"mtsp": TSPEdgeEmbedding,
"smtwtp": NoEdgeEmbedding,
}

if env_name not in embedding_registry:
raise ValueError(
f"Unknown environment name '{env_name}'. Available init embeddings: {embedding_registry.keys()}"
)

return embedding_registry[env_name](**config)


class TSPEdgeEmbedding(nn.Module):
"""Edge embedding module for the Traveling Salesman Problem (TSP) and related problems.
This module converts the cost matrix or the distances between nodes into embeddings that can be
used by the neural network. It supports sparsification to focus on a subset of relevant edges,
which is particularly useful for large graphs.
"""

def __init__(
self,
embedding_dim,
linear_bias=True,
sparsify=True,
k_sparse: int = None,
):
assert Batch is not None, (
"torch_geometric not found. Please install torch_geometric using instructions from "
"https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html."
)

super(TSPEdgeEmbedding, self).__init__()
node_dim = 1
self.k_sparse = k_sparse
self.sparsify = sparsify
self.edge_embed = nn.Linear(node_dim, embedding_dim, linear_bias)

def forward(self, td, init_embeddings: Tensor):
cost_matrix = get_distance_matrix(td["locs"])
batch = self._cost_matrix_to_graph(cost_matrix, init_embeddings)
return batch

def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tensor):
"""Convert batched cost_matrix to batched PyG graph, and calculate edge embeddings.
Args:
batch_cost_matrix: Tensor of shape [batch_size, n, n]
init_embedding: init embeddings
"""
graph_data = []
for index, cost_matrix in enumerate(batch_cost_matrix):
if self.sparsify:
edge_index, edge_attr = sparsify_graph(
cost_matrix, self.k_sparse, self_loop=False
)
else:
edge_index = get_full_graph_edge_index(
cost_matrix.shape[0], self_loop=False
).to(cost_matrix.device)
edge_attr = cost_matrix[edge_index[0], edge_index[1]]

graph = Data(
x=init_embeddings[index],
edge_index=edge_index,
edge_attr=edge_attr,
)
graph_data.append(graph)

batch = Batch.from_data_list(graph_data)
batch.edge_attr = self.edge_embed(batch.edge_attr)
return batch


class ATSPEdgeEmbedding(TSPEdgeEmbedding):
"""Edge embedding module for the Asymmetric Traveling Salesman Problem (ATSP).
Inherits from TSPEdgeEmbedding and adapts the edge embedding process to handle
asymmetric cost matrices, where the cost from node i to node j may not be the same as from j to i.
"""

def forward(self, td, init_embeddings: Tensor):
batch = self._cost_matrix_to_graph(td["cost_matrix"], init_embeddings)
return batch


class NoEdgeEmbedding(nn.Module):
"""A module for environments that do not require edge embeddings, or where edge features
are not used. This can be useful for simplifying models in problems where only node
features are relevant.
"""

def __init__(self, embedding_dim, self_loop=False, **kwargs):
assert Batch is not None, (
"torch_geometric not found. Please install torch_geometric using instructions from "
"https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html."
)

super(NoEdgeEmbedding, self).__init__()
self.embedding_dim = embedding_dim
self.self_loop = self_loop

def forward(self, td, init_embeddings: Tensor):
data_list = []
n = init_embeddings.shape[1]
device = init_embeddings.device
edge_index = get_full_graph_edge_index(n, self_loop=self.self_loop).to(device)

for node_embed in init_embeddings:
data = Data(
x=node_embed,
edge_index=edge_index,
edge_attr=torch.zeros((n, self.embedding_dim), device=device),
)
data_list.append(data)

batch = Batch.from_data_list(data_list)
return batch
15 changes: 5 additions & 10 deletions rl4co/models/nn/graph/gcn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

Expand All @@ -10,6 +9,7 @@
from torch_geometric.nn import GCNConv

from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.utils.ops import get_full_graph_edge_index
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand Down Expand Up @@ -47,10 +47,7 @@ def __init__(
)

# Generate edge index for a fully connected graph
adj_matrix = torch.ones(num_nodes, num_nodes)
if self_loop:
adj_matrix.fill_diagonal_(0) # No self-loops
self.edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0))
self.edge_index = get_full_graph_edge_index(num_nodes, self_loop)

# Define the GCN layers
self.gcn_layers = nn.ModuleList(
Expand Down Expand Up @@ -82,11 +79,9 @@ def forward(

# Check to update the edge index with different number of node
if num_node != self.edge_index.max().item() + 1:
adj_matrix = torch.ones(num_node, num_node)
if self.self_loop:
adj_matrix.fill_diagonal_(0)
edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0))
edge_index = edge_index.to(init_h.device)
edge_index = get_full_graph_edge_index(num_node, self.self_loop).to(
init_h.device
)
else:
edge_index = self.edge_index.to(init_h.device)

Expand Down
99 changes: 99 additions & 0 deletions rl4co/models/nn/graph/gnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn

try:
import torch_geometric.nn as gnn
except ImportError:
gnn = None

from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class GNNLayer(nn.Module):
"""Graph Neural Network Layer for processing graph structures.
Args:
units: The number of units in each linear transformation layer.
act_fn: The name of the activation function to use after each linear layer. Defaults to 'silu'.
agg_fn: The name of the global aggregation function to use for pooling features across the graph. Defaults to 'mean'.
"""

def __init__(self, units: int, act_fn: str = "silu", agg_fn: str = "mean"):
assert gnn is not None, (
"torch_geometric not found. Please install torch_geometric using instructions from "
"https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html."
)

super(GNNLayer, self).__init__()
self.units = units
self.act_fn = getattr(nn.functional, act_fn)
self.agg_fn = getattr(gnn, f"global_{agg_fn}_pool")

# Vertex updates
self.v_lin1 = nn.Linear(units, units)
self.v_lin2 = nn.Linear(units, units)
self.v_lin3 = nn.Linear(units, units)
self.v_lin4 = nn.Linear(units, units)
self.v_bn = gnn.BatchNorm(units)

# Edge updates
self.e_lin = nn.Linear(units, units)
self.e_bn = gnn.BatchNorm(units)

def forward(self, x, edge_index, edge_attr):
x0 = x
w0 = w = edge_attr

# Vertex updates
x1 = self.v_lin1(x0)
x2 = self.v_lin2(x0)
x3 = self.v_lin3(x0)
x4 = self.v_lin4(x0)
x = x0 + self.act_fn(
self.v_bn(
x1 + self.agg_fn(torch.sigmoid(w0) * x2[edge_index[1]], edge_index[0])
)
)

# Edge updates
w1 = self.e_lin(w0)
w = w0 + self.act_fn(self.e_bn(w1 + x3[edge_index[0]] + x4[edge_index[1]]))
return x, w


class GNNEncoder(nn.Module):
"""Anisotropic Graph Neural Network encoder with edge-gating mechanism as in Joshi et al. (2022)
Args:
num_layers: The number of GNN layers to stack in the network.
embedding_dim: The dimensionality of the embeddings for each node in the graph.
act_fn: The activation function to use in each GNNLayer, see https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions for available options. Defaults to 'silu'.
agg_fn: The aggregation function to use in each GNNLayer for pooling features. Options: 'add', 'mean', 'max'. Defaults to 'mean'.
"""

def __init__(self, num_layers: int, embedding_dim: int, act_fn="silu", agg_fn="mean"):
super(GNNEncoder, self).__init__()
self.act_fn = getattr(nn.functional, act_fn)
self.agg_fn = agg_fn

# Stack of GNN layers
self.layers = nn.ModuleList(
[GNNLayer(embedding_dim, act_fn, agg_fn) for _ in range(num_layers)]
)

def forward(self, x, edge_index, w):
"""Sequentially passes the input graph data through the stacked GNN layers,
applying specified transformations and aggregations to learn graph representations.
Args:
x: The node features of the graph with shape [num_nodes, embedding_dim].
edge_index: The edge indices of the graph with shape [2, num_edges].
w: The edge attributes or weights with shape [num_edges, embedding_dim].
"""
x = self.act_fn(x)
w = self.act_fn(w)
for layer in self.layers:
x, w = layer(x, edge_index, w)
return x, w
4 changes: 4 additions & 0 deletions rl4co/models/zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from rl4co.models.zoo.active_search import ActiveSearch
from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
from rl4co.models.zoo.common.nonautoregressive import (
NonAutoregressiveModel,
NonAutoregressivePolicy,
)
from rl4co.models.zoo.common.search import SearchBase
from rl4co.models.zoo.eas import EAS, EASEmb, EASLay
from rl4co.models.zoo.ham import (
Expand Down
4 changes: 4 additions & 0 deletions rl4co/models/zoo/common/nonautoregressive/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from rl4co.models.zoo.common.nonautoregressive.decoder import NonAutoregressiveDecoder
from rl4co.models.zoo.common.nonautoregressive.encoder import NonAutoregressiveEncoder
from rl4co.models.zoo.common.nonautoregressive.model import NonAutoregressiveModel
from rl4co.models.zoo.common.nonautoregressive.policy import NonAutoregressivePolicy
Loading

0 comments on commit fd58215

Please sign in to comment.