From 15dbdaffb87dab083896ca01f32441878be11d0a Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Thu, 17 Nov 2022 09:49:15 +0100 Subject: [PATCH] Check `ONNX` output equality (#5997) --- CHANGELOG.md | 2 +- test/nn/models/test_basic_gnn.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c44b695ef24f..61cef6ae0cc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888)) - Added TorchScript support for `AttentiveFP `([#5868](https://github.com/pyg-team/pytorch_geometric/pull/5868)) - Added `num_steps` argument to training and inference benchmarks ([#5898](https://github.com/pyg-team/pytorch_geometric/pull/5898)) -- Added `torch.onnx.export` support ([#5877](https://github.com/pyg-team/pytorch_geometric/pull/5877)) +- Added `torch.onnx.export` support ([#5877](https://github.com/pyg-team/pytorch_geometric/pull/5877), [#5997](https://github.com/pyg-team/pytorch_geometric/pull/5997)) - Enable VTune ITT in inference and training benchmarks ([#5830](https://github.com/pyg-team/pytorch_geometric/pull/5830), [#5878](https://github.com/pyg-team/pytorch_geometric/pull/5878)) - Add training benchmark ([#5774](https://github.com/pyg-team/pytorch_geometric/pull/5774)) - Added a "Link Prediction on MovieLens" Colab notebook ([#5823](https://github.com/pyg-team/pytorch_geometric/pull/5823)) diff --git a/test/nn/models/test_basic_gnn.py b/test/nn/models/test_basic_gnn.py index d002a0783967..da24bc943354 100644 --- a/test/nn/models/test_basic_gnn.py +++ b/test/nn/models/test_basic_gnn.py @@ -184,10 +184,12 @@ def forward(self, x, edge_index): model = MyModel() x = torch.randn(3, 8) - edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]]) + expected = model(x, edge_index) + assert expected.size() == (3, 16) torch.onnx.export(model, (x, edge_index), 'model.onnx', - input_names=('x', 'edge_index')) + input_names=('x', 'edge_index'), opset_version=16) model = onnx.load('model.onnx') onnx.checker.check_model(model) @@ -198,6 +200,7 @@ def forward(self, x, edge_index): 'x': x.numpy(), 'edge_index': edge_index.numpy() })[0] - assert out.shape == (3, 16) + out = torch.from_numpy(out) + assert torch.allclose(out, expected, atol=1e-6) os.remove('model.onnx')