Skip to content

Commit

Permalink
Check ONNX output equality (#5997)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 17, 2022
1 parent 3d8ca8b commit 15dbdaf
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))
- Added TorchScript support for `AttentiveFP `([#5868](https://github.com/pyg-team/pytorch_geometric/pull/5868))
- Added `num_steps` argument to training and inference benchmarks ([#5898](https://github.com/pyg-team/pytorch_geometric/pull/5898))
- Added `torch.onnx.export` support ([#5877](https://github.com/pyg-team/pytorch_geometric/pull/5877))
- Added `torch.onnx.export` support ([#5877](https://github.com/pyg-team/pytorch_geometric/pull/5877), [#5997](https://github.com/pyg-team/pytorch_geometric/pull/5997))
- Enable VTune ITT in inference and training benchmarks ([#5830](https://github.com/pyg-team/pytorch_geometric/pull/5830), [#5878](https://github.com/pyg-team/pytorch_geometric/pull/5878))
- Add training benchmark ([#5774](https://github.com/pyg-team/pytorch_geometric/pull/5774))
- Added a "Link Prediction on MovieLens" Colab notebook ([#5823](https://github.com/pyg-team/pytorch_geometric/pull/5823))
Expand Down
9 changes: 6 additions & 3 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,12 @@ def forward(self, x, edge_index):

model = MyModel()
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])
expected = model(x, edge_index)
assert expected.size() == (3, 16)

torch.onnx.export(model, (x, edge_index), 'model.onnx',
input_names=('x', 'edge_index'))
input_names=('x', 'edge_index'), opset_version=16)

model = onnx.load('model.onnx')
onnx.checker.check_model(model)
Expand All @@ -198,6 +200,7 @@ def forward(self, x, edge_index):
'x': x.numpy(),
'edge_index': edge_index.numpy()
})[0]
assert out.shape == (3, 16)
out = torch.from_numpy(out)
assert torch.allclose(out, expected, atol=1e-6)

os.remove('model.onnx')

0 comments on commit 15dbdaf

Please sign in to comment.