diff --git a/CHANGELOG.md b/CHANGELOG.md index f994d3125195..f4b192d5ad22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Do not fill `InMemoryDataset` cache on `dataset.num_features` ([#5264](https://github.com/pyg-team/pytorch_geometric/pull/5264)) - Changed tests relying on `dblp` datasets to instead use synthetic data ([#5250](https://github.com/pyg-team/pytorch_geometric/pull/5250)) - Fixed a bug for the initialization of activation function examples in `custom_graphgym` ([#5243](https://github.com/pyg-team/pytorch_geometric/pull/5243)) +- Allow any integer tensors when checking edge_index input to message passing ([5281](https://github.com/pyg-team/pytorch_geometric/pull/5281)) ### Removed ## [2.1.0] - 2022-08-17 diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index b2c74f91d75d..4c0a9bed996f 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -522,3 +522,22 @@ def test_message_passing_with_aggr_module(aggr_module): out = conv(x, edge_index) assert out.size(0) == 4 and out.size(1) in {8, 16} assert torch.allclose(conv(x, adj.t()), out) + + +def test_message_passing_int32_edge_index(): + # Check that we can dispatch an int32 edge_index up to aggregation + x = torch.randn(4, 8) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.int32) + edge_weight = torch.randn(edge_index.shape[1]) + + # Use a hook to promote the edge_index to long to workaround PyTorch CPU + # backend restriction to int64 for the index. + def cast_index_hook(module, inputs): + input_dict = inputs[-1] + input_dict['index'] = input_dict['index'].long() + return (input_dict, ) + + conv = MyConv(8, 32) + conv.register_aggregate_forward_pre_hook(cast_index_hook) + + assert conv(x, edge_index, edge_weight).size() == (4, 32) diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 20498c2416cf..265112742675 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -183,9 +183,11 @@ def __check_input__(self, edge_index, size): the_size: List[Optional[int]] = [None, None] if isinstance(edge_index, Tensor): - if not edge_index.dtype == torch.long: - raise ValueError(f"Expected 'edge_index' to be of type " - f"'torch.long' (got '{edge_index.dtype}')") + int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64) + + if edge_index.dtype not in int_dtypes: + raise ValueError(f"Expected 'edge_index' to be of integer " + f"type (got '{edge_index.dtype}')") if edge_index.dim() != 2: raise ValueError(f"Expected 'edge_index' to be two-dimensional" f" (got {edge_index.dim()} dimensions)") @@ -211,7 +213,7 @@ def __check_input__(self, edge_index, size): return the_size raise ValueError( - ('`MessagePassing.propagate` only supports `torch.LongTensor` of ' + ('`MessagePassing.propagate` only supports integer tensors of ' 'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for ' 'argument `edge_index`.'))