-
Notifications
You must be signed in to change notification settings - Fork 88
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
Showing
14 changed files
with
788 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from rl4co.models.nn.env_embeddings.context import env_context_embedding | ||
from rl4co.models.nn.env_embeddings.dynamic import env_dynamic_embedding | ||
from rl4co.models.nn.env_embeddings.edge import env_edge_embedding | ||
from rl4co.models.nn.env_embeddings.init import env_init_embedding |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from torch import Tensor | ||
|
||
try: | ||
from torch_geometric.data import Batch, Data | ||
except ImportError: | ||
Batch = Data = None | ||
|
||
from rl4co.utils.ops import get_distance_matrix, get_full_graph_edge_index, sparsify_graph | ||
from rl4co.utils.pylogger import get_pylogger | ||
|
||
log = get_pylogger(__name__) | ||
|
||
|
||
def env_edge_embedding(env_name: str, config: dict) -> nn.Module: | ||
"""Retrieve the edge embedding module specific to the environment. Edge embeddings are crucial for | ||
transforming the raw edge features into a format suitable for the neural network, especially in | ||
graph neural networks where edge features can significantly impact the model's performance. | ||
Args: | ||
env: Environment or its name. | ||
config: A dictionary of configuration options for the environment. | ||
""" | ||
embedding_registry = { | ||
"tsp": TSPEdgeEmbedding, | ||
"atsp": ATSPEdgeEmbedding, | ||
"cvrp": TSPEdgeEmbedding, | ||
"sdvrp": TSPEdgeEmbedding, | ||
"pctsp": TSPEdgeEmbedding, | ||
"spctsp": TSPEdgeEmbedding, | ||
"op": TSPEdgeEmbedding, | ||
"dpp": TSPEdgeEmbedding, | ||
"mdpp": TSPEdgeEmbedding, | ||
"pdp": TSPEdgeEmbedding, | ||
"mtsp": TSPEdgeEmbedding, | ||
"smtwtp": NoEdgeEmbedding, | ||
} | ||
|
||
if env_name not in embedding_registry: | ||
raise ValueError( | ||
f"Unknown environment name '{env_name}'. Available init embeddings: {embedding_registry.keys()}" | ||
) | ||
|
||
return embedding_registry[env_name](**config) | ||
|
||
|
||
class TSPEdgeEmbedding(nn.Module): | ||
"""Edge embedding module for the Traveling Salesman Problem (TSP) and related problems. | ||
This module converts the cost matrix or the distances between nodes into embeddings that can be | ||
used by the neural network. It supports sparsification to focus on a subset of relevant edges, | ||
which is particularly useful for large graphs. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
embedding_dim, | ||
linear_bias=True, | ||
sparsify=True, | ||
k_sparse: int = None, | ||
): | ||
assert Batch is not None, ( | ||
"torch_geometric not found. Please install torch_geometric using instructions from " | ||
"https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html." | ||
) | ||
|
||
super(TSPEdgeEmbedding, self).__init__() | ||
node_dim = 1 | ||
self.k_sparse = k_sparse | ||
self.sparsify = sparsify | ||
self.edge_embed = nn.Linear(node_dim, embedding_dim, linear_bias) | ||
|
||
def forward(self, td, init_embeddings: Tensor): | ||
cost_matrix = get_distance_matrix(td["locs"]) | ||
batch = self._cost_matrix_to_graph(cost_matrix, init_embeddings) | ||
return batch | ||
|
||
def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tensor): | ||
"""Convert batched cost_matrix to batched PyG graph, and calculate edge embeddings. | ||
Args: | ||
batch_cost_matrix: Tensor of shape [batch_size, n, n] | ||
init_embedding: init embeddings | ||
""" | ||
graph_data = [] | ||
for index, cost_matrix in enumerate(batch_cost_matrix): | ||
if self.sparsify: | ||
edge_index, edge_attr = sparsify_graph( | ||
cost_matrix, self.k_sparse, self_loop=False | ||
) | ||
else: | ||
edge_index = get_full_graph_edge_index( | ||
cost_matrix.shape[0], self_loop=False | ||
).to(cost_matrix.device) | ||
edge_attr = cost_matrix[edge_index[0], edge_index[1]] | ||
|
||
graph = Data( | ||
x=init_embeddings[index], | ||
edge_index=edge_index, | ||
edge_attr=edge_attr, | ||
) | ||
graph_data.append(graph) | ||
|
||
batch = Batch.from_data_list(graph_data) | ||
batch.edge_attr = self.edge_embed(batch.edge_attr) | ||
return batch | ||
|
||
|
||
class ATSPEdgeEmbedding(TSPEdgeEmbedding): | ||
"""Edge embedding module for the Asymmetric Traveling Salesman Problem (ATSP). | ||
Inherits from TSPEdgeEmbedding and adapts the edge embedding process to handle | ||
asymmetric cost matrices, where the cost from node i to node j may not be the same as from j to i. | ||
""" | ||
|
||
def forward(self, td, init_embeddings: Tensor): | ||
batch = self._cost_matrix_to_graph(td["cost_matrix"], init_embeddings) | ||
return batch | ||
|
||
|
||
class NoEdgeEmbedding(nn.Module): | ||
"""A module for environments that do not require edge embeddings, or where edge features | ||
are not used. This can be useful for simplifying models in problems where only node | ||
features are relevant. | ||
""" | ||
|
||
def __init__(self, embedding_dim, self_loop=False, **kwargs): | ||
assert Batch is not None, ( | ||
"torch_geometric not found. Please install torch_geometric using instructions from " | ||
"https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html." | ||
) | ||
|
||
super(NoEdgeEmbedding, self).__init__() | ||
self.embedding_dim = embedding_dim | ||
self.self_loop = self_loop | ||
|
||
def forward(self, td, init_embeddings: Tensor): | ||
data_list = [] | ||
n = init_embeddings.shape[1] | ||
device = init_embeddings.device | ||
edge_index = get_full_graph_edge_index(n, self_loop=self.self_loop).to(device) | ||
|
||
for node_embed in init_embeddings: | ||
data = Data( | ||
x=node_embed, | ||
edge_index=edge_index, | ||
edge_attr=torch.zeros((n, self.embedding_dim), device=device), | ||
) | ||
data_list.append(data) | ||
|
||
batch = Batch.from_data_list(data_list) | ||
return batch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
try: | ||
import torch_geometric.nn as gnn | ||
except ImportError: | ||
gnn = None | ||
|
||
from rl4co.utils.pylogger import get_pylogger | ||
|
||
log = get_pylogger(__name__) | ||
|
||
|
||
class GNNLayer(nn.Module): | ||
"""Graph Neural Network Layer for processing graph structures. | ||
Args: | ||
units: The number of units in each linear transformation layer. | ||
act_fn: The name of the activation function to use after each linear layer. Defaults to 'silu'. | ||
agg_fn: The name of the global aggregation function to use for pooling features across the graph. Defaults to 'mean'. | ||
""" | ||
|
||
def __init__(self, units: int, act_fn: str = "silu", agg_fn: str = "mean"): | ||
assert gnn is not None, ( | ||
"torch_geometric not found. Please install torch_geometric using instructions from " | ||
"https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html." | ||
) | ||
|
||
super(GNNLayer, self).__init__() | ||
self.units = units | ||
self.act_fn = getattr(nn.functional, act_fn) | ||
self.agg_fn = getattr(gnn, f"global_{agg_fn}_pool") | ||
|
||
# Vertex updates | ||
self.v_lin1 = nn.Linear(units, units) | ||
self.v_lin2 = nn.Linear(units, units) | ||
self.v_lin3 = nn.Linear(units, units) | ||
self.v_lin4 = nn.Linear(units, units) | ||
self.v_bn = gnn.BatchNorm(units) | ||
|
||
# Edge updates | ||
self.e_lin = nn.Linear(units, units) | ||
self.e_bn = gnn.BatchNorm(units) | ||
|
||
def forward(self, x, edge_index, edge_attr): | ||
x0 = x | ||
w0 = w = edge_attr | ||
|
||
# Vertex updates | ||
x1 = self.v_lin1(x0) | ||
x2 = self.v_lin2(x0) | ||
x3 = self.v_lin3(x0) | ||
x4 = self.v_lin4(x0) | ||
x = x0 + self.act_fn( | ||
self.v_bn( | ||
x1 + self.agg_fn(torch.sigmoid(w0) * x2[edge_index[1]], edge_index[0]) | ||
) | ||
) | ||
|
||
# Edge updates | ||
w1 = self.e_lin(w0) | ||
w = w0 + self.act_fn(self.e_bn(w1 + x3[edge_index[0]] + x4[edge_index[1]])) | ||
return x, w | ||
|
||
|
||
class GNNEncoder(nn.Module): | ||
"""Anisotropic Graph Neural Network encoder with edge-gating mechanism as in Joshi et al. (2022) | ||
Args: | ||
num_layers: The number of GNN layers to stack in the network. | ||
embedding_dim: The dimensionality of the embeddings for each node in the graph. | ||
act_fn: The activation function to use in each GNNLayer, see https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions for available options. Defaults to 'silu'. | ||
agg_fn: The aggregation function to use in each GNNLayer for pooling features. Options: 'add', 'mean', 'max'. Defaults to 'mean'. | ||
""" | ||
|
||
def __init__(self, num_layers: int, embedding_dim: int, act_fn="silu", agg_fn="mean"): | ||
super(GNNEncoder, self).__init__() | ||
self.act_fn = getattr(nn.functional, act_fn) | ||
self.agg_fn = agg_fn | ||
|
||
# Stack of GNN layers | ||
self.layers = nn.ModuleList( | ||
[GNNLayer(embedding_dim, act_fn, agg_fn) for _ in range(num_layers)] | ||
) | ||
|
||
def forward(self, x, edge_index, w): | ||
"""Sequentially passes the input graph data through the stacked GNN layers, | ||
applying specified transformations and aggregations to learn graph representations. | ||
Args: | ||
x: The node features of the graph with shape [num_nodes, embedding_dim]. | ||
edge_index: The edge indices of the graph with shape [2, num_edges]. | ||
w: The edge attributes or weights with shape [num_edges, embedding_dim]. | ||
""" | ||
x = self.act_fn(x) | ||
w = self.act_fn(w) | ||
for layer in self.layers: | ||
x, w = layer(x, edge_index, w) | ||
return x, w |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from rl4co.models.zoo.common.nonautoregressive.decoder import NonAutoregressiveDecoder | ||
from rl4co.models.zoo.common.nonautoregressive.encoder import NonAutoregressiveEncoder | ||
from rl4co.models.zoo.common.nonautoregressive.model import NonAutoregressiveModel | ||
from rl4co.models.zoo.common.nonautoregressive.policy import NonAutoregressivePolicy |
Oops, something went wrong.