Skip to content

Commit

Permalink
Fix device bug in get_degree_histogram (#7830)
Browse files Browse the repository at this point in the history
This PR fixes the following error:
`RuntimeError: Expected all tensors to be on the same device, but found
at least two devices, xpu:0 and cpu!`
`deg_histogram` now inherits the device type from `edge_index`.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
DamianSzwichtenberg and rusty1s authored Aug 2, 2023
1 parent bb612f8 commit 20362ee
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `PrefetchLoader` capabilities ([#7376](https://github.com/pyg-team/pytorch_geometric/pull/7376), [#7378](https://github.com/pyg-team/pytorch_geometric/pull/7378), [#7383](https://github.com/pyg-team/pytorch_geometric/pull/7383))
- Added an example for hierarchical sampling ([#7244](https://github.com/pyg-team/pytorch_geometric/pull/7244))
- Added Kùzu remote backend examples ([#7298](https://github.com/pyg-team/pytorch_geometric/pull/7298))
- Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))
- Added an optional `add_pad_mask` argument to the `Pad` transform ([#7339](https://github.com/pyg-team/pytorch_geometric/pull/7339))
- Added `keep_inter_cluster_edges` option to `ClusterData` to support inter-subgraph edge connections when doing graph partitioning ([#7326](https://github.com/pyg-team/pytorch_geometric/pull/7326))
- Unify graph pooling framework ([#7308](https://github.com/pyg-team/pytorch_geometric/pull/7308), [#7625](https://github.com/pyg-team/pytorch_geometric/pull/7625))
Expand Down Expand Up @@ -78,6 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))
- Fixed device issue in `PNAConv.get_degree_histogram` ([#7830](https://github.com/pyg-team/pytorch_geometric/pull/7830))
- Fixed the shape of `edge_label_time` when using temporal sampling on homogeneous graphs ([#7807](https://github.com/pyg-team/pytorch_geometric/pull/7807))
- Made `FieldStatus` enum picklable to avoid `PicklingError` in a multi-process setting ([#7808](https://github.com/pyg-team/pytorch_geometric/pull/7808))
- Fixed `edge_label_index` computation in `LinkNeighborLoader` for the homogeneous+`disjoint` mode ([#7791](https://github.com/pyg-team/pytorch_geometric/pull/7791))
Expand Down
17 changes: 9 additions & 8 deletions torch_geometric/nn/conv/pna_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,15 @@ def get_degree_histogram(loader: DataLoader) -> Tensor:
argument in :class:`PNAConv`."""
deg_histogram = torch.zeros(1, dtype=torch.long)
for data in loader:
d = degree(data.edge_index[1], num_nodes=data.num_nodes,
dtype=torch.long)
d_bincount = torch.bincount(d, minlength=deg_histogram.numel())
if d_bincount.size(0) > deg_histogram.size(0):
d_bincount[:deg_histogram.size(0)] += deg_histogram
deg_histogram = d_bincount
deg = degree(data.edge_index[1], num_nodes=data.num_nodes,
dtype=torch.long)
deg_bincount = torch.bincount(deg, minlength=deg_histogram.numel())
deg_histogram = deg_histogram.to(deg_bincount.device)
if deg_bincount.numel() > deg_histogram.numel():
deg_bincount[:deg_histogram.size(0)] += deg_histogram
deg_histogram = deg_bincount
else:
assert d_bincount.size(0) == deg_histogram.size(0)
deg_histogram += d_bincount
assert deg_bincount.numel() == deg_histogram.numel()
deg_histogram += deg_bincount

return deg_histogram

0 comments on commit 20362ee

Please sign in to comment.