🐛 Describe the bug

I am having an issue where the embedding of the same graph comes out different depending on which batch it is in, when using `InstanceNorm` as the normalisation layer inside the `DeepGCNLayer` wrapper. The problem depends on `track_running_stats`: the mismatch occurs when it is set to `False`, but doesn't seem to occur when it is set to `True`. I'm not sure whether I misunderstand how instance normalisation should work on graphs, but I thought that for a single graph, each feature would be normalised using the statistics (mean and variance) of that feature over all nodes in that graph, so the result should be unaffected by whatever else is in the batch. Either there is a bug in the code, or my expectation of how this should work is wrong. Here's a script that reproduces the issue; switch the `track_running_stats` model argument to see the different results.
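Before the reproduction script, here is a minimal NumPy sketch (not PyG code) of the expectation described above. It contrasts per-graph instance normalisation, where each feature is normalised with the mean and variance of the nodes of one graph only, with normalising over every node in the mini-batch at once. The helpers `per_graph_instance_norm` and `whole_batch_norm` are hypothetical, the affine weight and bias are left out, and the `batch` array only mimics the graph-index vector of a PyG mini-batch.

```python
import numpy as np


def per_graph_instance_norm(x, batch, eps=1e-5):
    """Normalise each feature with the mean/variance of the nodes in the same graph.

    x: (num_nodes, num_features) node features
    batch: (num_nodes,) graph index of every node (hypothetical stand-in for a PyG batch vector)
    """
    out = np.empty_like(x, dtype=float)
    for g in np.unique(batch):
        mask = batch == g
        mean = x[mask].mean(axis=0)
        var = x[mask].var(axis=0)  # population (biased) variance
        out[mask] = (x[mask] - mean) / np.sqrt(var + eps)
    return out


def whole_batch_norm(x, eps=1e-5):
    """Normalise each feature with statistics taken over every node in the mini-batch."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


rng = np.random.default_rng(0)
fixed = rng.normal(size=(20, 4))               # the graph whose embedding should be stable
others_a = rng.normal(size=(30, 4))            # other nodes in mini-batch A (graph index 0)
others_b = rng.normal(loc=3.0, size=(50, 4))   # other nodes in mini-batch B (graph index 0)

# The fixed graph is appended last (graph index 1) in both mini-batches
x_a = np.concatenate([others_a, fixed])
batch_a = np.concatenate([np.zeros(30, dtype=int), np.ones(20, dtype=int)])
x_b = np.concatenate([others_b, fixed])
batch_b = np.concatenate([np.zeros(50, dtype=int), np.ones(20, dtype=int)])

per_graph_a = per_graph_instance_norm(x_a, batch_a)[-20:]
per_graph_b = per_graph_instance_norm(x_b, batch_b)[-20:]
print(np.allclose(per_graph_a, per_graph_b))  # True: only the fixed graph's own nodes are used

whole_a = whole_batch_norm(x_a)[-20:]
whole_b = whole_batch_norm(x_b)[-20:]
print(np.allclose(whole_a, whole_b))  # False: the other nodes shift the statistics
```

Only the per-graph version gives the fixed graph identical outputs in both batches. The sketch illustrates the two possibilities; it does not show which statistics the `InstanceNorm` in the script actually uses.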
```python
import random
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn.conv import RGCNConv
from torch_geometric.nn.models import DeepGCNLayer
from torch_geometric.nn.norm import InstanceNorm
from torch_geometric.nn.pool import global_mean_pool

random.seed(10)
np.random.seed(10)
torch.random.manual_seed(10)

class SimpleDeepGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_relations, num_layers=3,
                 track_running_stats=True):
        super().__init__()
        self.node_encoder = nn.Linear(in_channels, hidden_channels)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            conv = RGCNConv(hidden_channels, hidden_channels, num_relations)
            norm = InstanceNorm(in_channels=hidden_channels, track_running_stats=track_running_stats, affine=True)
            act = nn.ReLU()
            # 'res' block: conv -> norm -> act, then add the residual input
            layer = DeepGCNLayer(conv, norm, act, block='res')
            self.layers.append(layer)
        self.out_layer = nn.Linear(hidden_channels, out_channels)

    def forward(self, batch):
        x, edge_index, edge_type = batch.x, batch.edge_index, batch.edge_type
        x = self.node_encoder(x)
        for layer in self.layers:
            x = layer(x, edge_index, edge_type)
        x = self.out_layer(x)
        # One embedding per graph: average over the nodes of each graph
        return global_mean_pool(x, batch.batch)

def create_relational_graph(num_nodes, num_edges, num_node_features, num_edge_types):
    x = torch.randn(num_nodes, num_node_features)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_type = torch.randint(0, num_edge_types, (num_edges,))
    return Data(x=x, edge_index=edge_index, edge_type=edge_type, binary_hash=f"{np.random.randint(0, 10000)}")

def create_batch_graphs(num_graphs, num_node_features, num_edge_types):
    # Sizes are drawn once, so all random graphs in one batch share the same node/edge counts
    num_nodes = np.random.randint(50, 100)
    num_edges = np.random.randint(100, 200)
    return [create_relational_graph(num_nodes, num_edges, num_node_features, num_edge_types)
            for _ in range(num_graphs)]

# Create a fixed graph for testing
fixed_graph = create_relational_graph(num_nodes=20, num_edges=40, num_node_features=256, num_edge_types=3)
# Create two batches of random graphs
graphs1 = create_batch_graphs(num_graphs=5, num_node_features=256, num_edge_types=3)
graphs2 = create_batch_graphs(num_graphs=8, num_node_features=256, num_edge_types=3)
# Append the fixed graph to the end of each batch
graphs1.append(fixed_graph)
graphs2.append(fixed_graph)
# Create data loaders (each list becomes a single mini-batch, order preserved)
loader1 = DataLoader(graphs1, batch_size=len(graphs1), shuffle=False)
loader2 = DataLoader(graphs2, batch_size=len(graphs2), shuffle=False)
# Set track_running_stats=False to reproduce the mismatch
model = SimpleDeepGCN(256, 1000, 128, 3, track_running_stats=True)

def compare_embeddings(emb1, emb2, rtol=1e-5, atol=1e-8):
    """Compare two numpy arrays of embeddings.

    Args:
        emb1, emb2: numpy arrays of the same shape
        rtol: relative tolerance parameter
        atol: absolute tolerance parameter
    Returns:
        bool: True if the arrays are equal within the given tolerance, False otherwise
    """
    return np.allclose(emb1, emb2, rtol=rtol, atol=atol)

# Test the embeddings
model.eval()
with torch.no_grad():
    for batch1, batch2 in zip(loader1, loader2):
        # Forward pass for batch1
        output1 = model(batch1).cpu().numpy()
        fixed_graph_emb1 = output1[-1]  # The fixed graph is the last graph in the batch
        # Forward pass for batch2
        output2 = model(batch2).cpu().numpy()
        fixed_graph_emb2 = output2[-1]  # The fixed graph is the last graph in the batch
        # Compare the embeddings
        are_embeddings_equal = compare_embeddings(fixed_graph_emb1, fixed_graph_emb2)
        print(f"Are the embeddings of the fixed graph equal in both batches? {are_embeddings_equal}")
```
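For the other half of the comparison, the PyTorch `InstanceNorm1d` documentation states that with `track_running_stats=True` the layer keeps running estimates of the mean and variance during training and uses them for normalisation during evaluation, whereas with `False` it always uses statistics computed from the input. Here is a minimal NumPy sketch of that eval-mode path, using a hypothetical helper `instance_norm_eval`. It assumes the running buffers and affine parameters still hold their default initial values (mean 0, variance 1, weight 1, bias 0), which should be the case in the script above because it never runs a training step.

```python
import numpy as np


def instance_norm_eval(x, running_mean, running_var, weight, bias, eps=1e-5):
    """Eval-mode normalisation with stored statistics (assumed track_running_stats=True path)."""
    return (x - running_mean) / np.sqrt(running_var + eps) * weight + bias


num_features = 4
# Assumed default values before any training step
running_mean = np.zeros(num_features)
running_var = np.ones(num_features)
weight = np.ones(num_features)
bias = np.zeros(num_features)

rng = np.random.default_rng(0)
fixed = rng.normal(size=(20, num_features))
batch_a = np.concatenate([rng.normal(size=(30, num_features)), fixed])
batch_b = np.concatenate([rng.normal(loc=3.0, size=(50, num_features)), fixed])

out_a = instance_norm_eval(batch_a, running_mean, running_var, weight, bias)[-20:]
out_b = instance_norm_eval(batch_b, running_mean, running_var, weight, bias)[-20:]
print(np.allclose(out_a, out_b))  # True: no statistic depends on the other nodes
# With untouched buffers the layer is close to an identity map
print(np.allclose(out_a, fixed, atol=1e-4))  # True
```

Under that assumption each node is transformed independently of the rest of the batch, which would be consistent with the `True` setting giving matching embeddings. It would also mean that, for an untrained model, the normalisation is close to an identity map, so the `True` run says little about whether per-graph statistics are computed correctly.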
Versions
Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.1
Libc version: N/A
Python version: 3.10.14 (main, May 6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3 Pro
Versions of relevant libraries:
[pip3] numpy==1.26.1
[pip3] pytorch-metric-learning==2.4.1
[pip3] torch==2.1.0
[pip3] torch_cluster==1.6.3
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch_spline_conv==1.2.2
[pip3] torchaudio==2.1.0
[pip3] torchmetrics==1.3.0
[conda] numpy 1.26.1 pypi_0 pypi
[conda] pytorch-metric-learning 2.4.1 pypi_0 pypi
[conda] torch 2.1.0 pypi_0 pypi
[conda] torch-cluster 1.6.3 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torch-spline-conv 1.2.2 pypi_0 pypi
[conda] torchaudio 2.1.0 pypi_0 pypi
[conda] torchmetrics 1.3.0 pypi_0 pypi