Fix RGATConv different device error (#5187)
* fix rgatconv device

* changelog
rusty1s authored Aug 10, 2022
1 parent 797e3d9 commit fea52ab
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -73,6 +73,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Fixed `RGATConv` device mismatches for `f-scaled` mode ([#5187](https://github.com/pyg-team/pytorch_geometric/pull/5187))
 - Allow for multi-dimensional `edge_labels` in `LinkNeighborLoader` ([#5186](https://github.com/pyg-team/pytorch_geometric/pull/5186))
 - Fixed `GINEConv` bug with non-sequential input ([#5154](https://github.com/pyg-team/pytorch_geometric/pull/5154))
 - Improved error message ([#5095](https://github.com/pyg-team/pytorch_geometric/pull/5095))
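For context, a sketch of how the bug surfaced before this commit: running RGATConv with mod="f-scaled" on a CUDA device raised a device-mismatch RuntimeError, because message() allocated a helper tensor on the CPU. Model and graph sizes below are illustrative assumptions, not taken from the PR:

import torch
from torch_geometric.nn import RGATConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Toy relational graph: 10 nodes, 40 edges, 3 relation types.
conv = RGATConv(in_channels=16, out_channels=16, num_relations=3,
                mod='f-scaled').to(device)
x = torch.randn(10, 16, device=device)
edge_index = torch.randint(0, 10, (2, 40), device=device)
edge_type = torch.randint(0, 3, (40,), device=device)

# Before this commit, this call failed on CUDA with a device-mismatch
# RuntimeError coming from the scatter_add inside message().
out = conv(x, edge_index, edge_type)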
6 changes: 3 additions & 3 deletions torch_geometric/nn/conv/rgat_conv.py
@@ -441,7 +441,7 @@ def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
 
         elif self.mod == "scaled":
             if self.attention_mode == "additive-self-attention":
-                ones = torch.ones(index.size())
+                ones = alpha.new_ones(index.size())
                 degree = scatter_add(ones, index,
                                      dim_size=size_i)[index].unsqueeze(-1)
                 degree = torch.matmul(degree, self.l1) + self.b1
@@ -453,7 +453,7 @@ def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
                     alpha.view(-1, self.heads, 1),
                     degree.view(-1, 1, self.out_channels))
             elif self.attention_mode == "multiplicative-self-attention":
-                ones = torch.ones(index.size())
+                ones = alpha.new_ones(index.size())
                 degree = scatter_add(ones, index,
                                      dim_size=size_i)[index].unsqueeze(-1)
                 degree = torch.matmul(degree, self.l1) + self.b1
@@ -469,7 +469,7 @@ def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
             alpha = torch.where(alpha > 0, alpha + 1, alpha)
 
         elif self.mod == "f-scaled":
-            ones = torch.ones(index.size())
+            ones = alpha.new_ones(index.size())
             degree = scatter_add(ones, index,
                                  dim_size=size_i)[index].unsqueeze(-1)
             alpha = alpha * degree
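The fix is the same one-line change in all three branches: torch.ones(index.size()) allocates on the default device (normally the CPU) regardless of where the module's tensors live, while Tensor.new_ones inherits both dtype and device from alpha, so the subsequent scatter_add operates on tensors on a single device. A minimal sketch of the difference (shapes are illustrative; scatter_add is assumed to be torch_scatter's, as used by rgat_conv.py):

import torch
from torch_scatter import scatter_add

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
alpha = torch.rand(40, 4, device=device)            # per-edge attention scores
index = torch.randint(0, 10, (40,), device=device)  # target node of each edge

# Old: always allocates on the default (CPU) device, so on CUDA the
# scatter_add below mixed devices and raised a RuntimeError:
#   ones = torch.ones(index.size())

# New: inherits dtype and device from alpha:
ones = alpha.new_ones(index.size())

# Degree of each edge's target node, gathered back per edge --
# the same pattern message() uses before rescaling alpha.
degree = scatter_add(ones, index, dim_size=10)[index].unsqueeze(-1)

Calling new_* factory methods on an existing tensor is the usual idiom for keeping helper tensors on the caller's device without threading an explicit device argument through.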
