diff --git a/CHANGELOG.md b/CHANGELOG.md index a136305d0507..3cc3669092fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -73,6 +73,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- Fixed `RGATConv` device mismatches for `f-scaled` mode ([#5187](https://github.com/pyg-team/pytorch_geometric/pull/5187)] - Allow for multi-dimensional `edge_labels` in `LinkNeighborLoader` ([#5186](https://github.com/pyg-team/pytorch_geometric/pull/5186)] - Fixed `GINEConv` bug with non-sequential input ([#5154](https://github.com/pyg-team/pytorch_geometric/pull/5154)] - Improved error message ([#5095](https://github.com/pyg-team/pytorch_geometric/pull/5095)) diff --git a/torch_geometric/nn/conv/rgat_conv.py b/torch_geometric/nn/conv/rgat_conv.py index fd13988b43ef..9fa96fc648e9 100644 --- a/torch_geometric/nn/conv/rgat_conv.py +++ b/torch_geometric/nn/conv/rgat_conv.py @@ -441,7 +441,7 @@ def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor, elif self.mod == "scaled": if self.attention_mode == "additive-self-attention": - ones = torch.ones(index.size()) + ones = alpha.new_ones(index.size()) degree = scatter_add(ones, index, dim_size=size_i)[index].unsqueeze(-1) degree = torch.matmul(degree, self.l1) + self.b1 @@ -453,7 +453,7 @@ def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor, alpha.view(-1, self.heads, 1), degree.view(-1, 1, self.out_channels)) elif self.attention_mode == "multiplicative-self-attention": - ones = torch.ones(index.size()) + ones = alpha.new_ones(index.size()) degree = scatter_add(ones, index, dim_size=size_i)[index].unsqueeze(-1) degree = torch.matmul(degree, self.l1) + self.b1 @@ -469,7 +469,7 @@ def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor, alpha = torch.where(alpha > 0, alpha + 1, alpha) elif self.mod == "f-scaled": - ones = torch.ones(index.size()) + ones = alpha.new_ones(index.size()) degree = scatter_add(ones, index, dim_size=size_i)[index].unsqueeze(-1) alpha = alpha * degree