From af0f5f44d4dfa8accd898df3e8523d06b67940b0 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 5 Mar 2024 15:01:07 +0100 Subject: [PATCH] Remove filtering of node/edge types in `trim_to_layer` (#9021) This is not safe in most cases, since a filtering of an empty edge type may lead to the unexpected drop of node features. Fixes https://github.com/pyg-team/pytorch_geometric/issues/9015 --- CHANGELOG.md | 1 + test/utils/test_trim_to_layer.py | 36 ------------------------- torch_geometric/utils/_trim_to_layer.py | 18 +------------ 3 files changed, 2 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 86c17b4064ba..7fcdf137e5ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Remove filtering of node/edge types in `trim_to_layer` functionality ([#9021](https://github.com/pyg-team/pytorch_geometric/pull/9021)) - Default to `scatter` operations in `MessagePassing` in case `torch.use_deterministic_algorithms` is not set ([#9009](https://github.com/pyg-team/pytorch_geometric/pull/9009)) - Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001)) - Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937)) diff --git a/test/utils/test_trim_to_layer.py b/test/utils/test_trim_to_layer.py index c1f1c0fbd038..f77bb6945360 100644 --- a/test/utils/test_trim_to_layer.py +++ b/test/utils/test_trim_to_layer.py @@ -197,39 +197,3 @@ def test_trim_to_layer_with_neighbor_loader(): assert out2.size() == (2, 16) assert torch.allclose(out1, out2, atol=1e-6) - - -def test_trim_to_layer_filtering(): - x_dict = { - 'paper': torch.rand((13, 128)), - 'author': torch.rand((5, 128)), - 'field_of_study': torch.rand((6, 128)) - } - edge_index_dict = { - ('author', 'writes', 'paper'): - torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 2]]), - ('paper', 'has_topic', 'field_of_study'): - torch.tensor([[6, 7, 8, 9], [0, 0, 1, 1]]) - } - num_sampled_nodes_dict = { - 'paper': [1, 2, 10], - 'author': [0, 2, 3], - 'field_of_study': [0, 2, 4] - } - num_sampled_edges_dict = { - ('author', 'writes', 'paper'): [2, 3], - ('paper', 'has_topic', 'field_of_study'): [0, 4] - } - x_dict, edge_index_dict, _ = trim_to_layer( - layer=1, - num_sampled_nodes_per_hop=num_sampled_nodes_dict, - num_sampled_edges_per_hop=num_sampled_edges_dict, - x=x_dict, - edge_index=edge_index_dict, - ) - assert list(edge_index_dict.keys()) == [('author', 'writes', 'paper')] - assert torch.equal(edge_index_dict[('author', 'writes', 'paper')], - torch.tensor([[0, 1], [0, 0]])) - assert x_dict['paper'].size() == (3, 128) - assert x_dict['author'].size() == (2, 128) - assert x_dict['field_of_study'].size() == (2, 128) diff --git a/torch_geometric/utils/_trim_to_layer.py b/torch_geometric/utils/_trim_to_layer.py index a79a3ae92dd9..c38bb96d60e7 100644 --- a/torch_geometric/utils/_trim_to_layer.py +++ b/torch_geometric/utils/_trim_to_layer.py @@ -1,5 +1,4 @@ -import copy -from typing import Any, Dict, List, Optional, Tuple, Union, overload +from typing import Dict, List, Optional, Tuple, Union, overload import torch from torch import Tensor @@ -17,18 +16,6 @@ ) -def filter_empty_entries( - input_dict: Dict[Union[Any], Tensor]) -> Dict[Any, Tensor]: - r"""Removes empty tensors from a dictionary. This avoids unnecessary - computation when some node/edge types are non-reachable after trimming. - """ - out_dict = copy.copy(input_dict) - for key, value in input_dict.items(): - if value.numel() == 0: - del out_dict[key] - return out_dict - - @overload def trim_to_layer( layer: int, @@ -96,7 +83,6 @@ def trim_to_layer( k: trim_feat(v, layer, num_sampled_nodes_per_hop[k]) for k, v in x.items() } - x = filter_empty_entries(x) assert isinstance(edge_index, dict) edge_index = { @@ -110,7 +96,6 @@ def trim_to_layer( ) for k, v in edge_index.items() } - edge_index = filter_empty_entries(edge_index) if edge_attr is not None: assert isinstance(edge_attr, dict) @@ -118,7 +103,6 @@ def trim_to_layer( k: trim_feat(v, layer, num_sampled_edges_per_hop[k]) for k, v in edge_attr.items() } - edge_attr = filter_empty_entries(edge_attr) return x, edge_index, edge_attr