Skip to content

Commit

Permalink
Update exphormer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
phoeenniixx committed Nov 16, 2024
1 parent ebd2bf3 commit 39ee4f9
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions torch_geometric/transforms/exphormer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn as nn

Expand Down Expand Up @@ -45,26 +47,45 @@ def __init__(self, hidden_dim: int, num_layers: int = 3,
self.dropout = nn.Dropout(dropout)

def forward(self, data: Data) -> torch.Tensor:
if data.x.size(0) == 0:

if data.x is None:
raise ValueError("Input data.x cannot be None")
x: torch.Tensor = data.x

if x.size(0) == 0:
raise ValueError("Input graph is empty.")

if not hasattr(data, 'edge_index') or data.edge_index is None:
raise ValueError(
"Input data must contain 'edge_index' for message passing.")
x, edge_index = data.x, data.edge_index
edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None
edge_index: torch.Tensor = data.edge_index

edge_attr: Optional[torch.Tensor] = data.edge_attr if hasattr(
data, 'edge_attr') else None

batch_size = x.size(0)

if self.virtual_node_transform is not None:
data = self.virtual_node_transform(data)
x, edge_index = data.x, data.edge_index
if data.x is None:
raise ValueError("Virtual node transform resulted in None x")
x = data.x
if data.edge_index is None:
raise ValueError(
"Virtual node transform resulted in None edge_index")
edge_index = data.edge_index

for layer in self.layers:
residual = x
local_out = layer['local'](x, edge_index, edge_attr)
expander_out = 0
expander_out = torch.zeros_like(x)
if self.use_expander and layer['expander'] is not None:
expander_out, _ = layer['expander'](
x[:batch_size + self.num_virtual_nodes],
batch_size + self.num_virtual_nodes)

node_subset = x[:batch_size + self.num_virtual_nodes]
expander_out, _ = layer['expander'](node_subset, batch_size +
self.num_virtual_nodes)

x = layer['layer_norm'](residual + local_out + expander_out)
x = x + layer['ffn'](x)

return x[:batch_size]

0 comments on commit 39ee4f9

Please sign in to comment.