From b3c03185c848d210f6ae5d3416e52851c8c68102 Mon Sep 17 00:00:00 2001 From: Francesco Landolfi Date: Wed, 14 Sep 2022 17:11:48 +0200 Subject: [PATCH] Fix `node_attributes` shape in `read_tu_data` (#5441) * Fix node_attributes shape * Update tu.py * Update CHANGELOG.md Co-authored-by: Matthias Fey --- CHANGELOG.md | 1 + torch_geometric/io/tu.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c6b16447fce..70174781bb5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240)) - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) ### Changed +- Fixed a bug in `TUDataset` in which node features were wrongly constructed whenever `node_attributes` only hold a single feature (*e.g.*, in `PROTEINS`) ([#5441](https://github.com/pyg-team/pytorch_geometric/pull/5411)) - Breaking change: removed `num_neighbors` as an attribute of loader ([#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404)) - `ASAPooling` is now jittable ([#5395](https://github.com/pyg-team/pytorch_geometric/pull/5395)) - Updated unsupervised `GraphSAGE` example to leverage `LinkNeighborLoader` ([#5317](https://github.com/pyg-team/pytorch_geometric/pull/5317)) diff --git a/torch_geometric/io/tu.py b/torch_geometric/io/tu.py index 4d85d8ea05cc..5d23f68d2b14 100644 --- a/torch_geometric/io/tu.py +++ b/torch_geometric/io/tu.py @@ -27,6 +27,8 @@ def read_tu_data(folder, prefix): node_attributes = torch.empty((batch.size(0), 0)) if 'node_attributes' in names: node_attributes = read_file(folder, prefix, 'node_attributes') + if node_attributes.dim() == 1: + node_attributes = node_attributes.unsqueeze(-1) node_labels = torch.empty((batch.size(0), 0)) if 'node_labels' in names: @@ -41,6 +43,8 @@ def read_tu_data(folder, prefix): edge_attributes = torch.empty((edge_index.size(1), 0)) if 'edge_attributes' in names: edge_attributes = read_file(folder, prefix, 'edge_attributes') + if edge_attributes.dim() == 1: + edge_attributes = edge_attributes.unsqueeze(-1) edge_labels = torch.empty((edge_index.size(1), 0)) if 'edge_labels' in names: