Skip to content

Commit

Permalink
Fix errors of the RevGNN example (#4715)
Browse files Browse the repository at this point in the history
* Fix errors

* changelog

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
lightaime and rusty1s authored May 25, 2022
1 parent f482cb7 commit efffdc3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715))
- Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672), [#4676](https://github.com/pyg-team/pytorch_geometric/pull/4676))
- Added `unbatch` functionality ([#4628](https://github.com/pyg-team/pytorch_geometric/pull/4628))
- Confirm that `to_hetero()` works with custom functions, *e.g.*, `dropout_adj` ([4653](https://github.com/pyg-team/pytorch_geometric/pull/4653))
Expand Down
7 changes: 4 additions & 3 deletions examples/rev_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class GNNBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__(in_channels)
super().__init__()
self.norm = LayerNorm(in_channels, elementwise_affine=True)
self.conv = SAGEConv(in_channels, out_channels)

Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self, in_channels, hidden_channels, out_channels, num_layers,

assert hidden_channels % num_groups == 0
self.convs = torch.nn.ModuleList()
for _ in range(self.num_layers):
for _ in range(num_layers):
conv = GNNBlock(
hidden_channels // num_groups,
hidden_channels // num_groups,
Expand All @@ -63,14 +63,15 @@ def reset_parameters(self):
conv.reset_parameters()

def forward(self, x, edge_index):
x = self.lin1(x)

# Generate a dropout mask which will be shared across GNN blocks:
mask = None
if self.training and self.dropout > 0:
mask = torch.zeros_like(x).bernoulli_(1 - self.dropout)
mask = mask.requires_grad_(False)
mask = mask / (1 - self.dropout)

x = self.lin1(x)
for conv in self.convs:
x = conv(x, edge_index, mask)
x = self.norm(x).relu()
Expand Down

0 comments on commit efffdc3

Please sign in to comment.