diff --git a/CHANGELOG.md b/CHANGELOG.md index 65ee963b8a62..1316c353fcfe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Fixed `subgraph` on unordered inputs ([#7187](https://github.com/pyg-team/pytorch_geometric/pull/7187)) - Allow missing node types in `HeteroDictLinear` ([#7185](https://github.com/pyg-team/pytorch_geometric/pull/7185)) - Optimized `from_networkx` memory footprint by reducing unnecessary copies ([#7119](https://github.com/pyg-team/pytorch_geometric/pull/7119)) - Added an optional `batch_size` argument to `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135)) diff --git a/torch_geometric/utils/subgraph.py b/torch_geometric/utils/subgraph.py index b20a645b6163..490e53a6b78f 100644 --- a/torch_geometric/utils/subgraph.py +++ b/torch_geometric/utils/subgraph.py @@ -87,13 +87,13 @@ def subgraph( if isinstance(subset, (list, tuple)): subset = torch.tensor(subset, dtype=torch.long, device=device) - if subset.dtype == torch.bool or subset.dtype == torch.uint8: - num_nodes = subset.size(0) - else: + if subset.dtype != torch.bool: num_nodes = maybe_num_nodes(edge_index, num_nodes) - subset = index_to_mask(subset, size=num_nodes) + node_mask = index_to_mask(subset, size=num_nodes) + else: + num_nodes = subset.size(0) + node_mask = subset - node_mask = subset edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] edge_attr = edge_attr[edge_mask] if edge_attr is not None else None @@ -101,7 +101,7 @@ def subgraph( if relabel_nodes: node_idx = torch.zeros(node_mask.size(0), dtype=torch.long, device=device) - node_idx[subset] = torch.arange(subset.sum().item(), device=device) + node_idx[subset] = torch.arange(node_mask.sum().item(), device=device) edge_index = node_idx[edge_index] if return_edge_mask: @@ -168,22 +168,28 @@ def bipartite_subgraph( if src_subset.dtype != torch.bool: src_size = int(edge_index[0].max()) + 1 if size is None else size[0] - src_subset = index_to_mask(src_subset, size=src_size) + src_node_mask = index_to_mask(src_subset, size=src_size) + else: + src_size = src_subset.size(0) + src_node_mask = src_subset + if dst_subset.dtype != torch.bool: dst_size = int(edge_index[1].max()) + 1 if size is None else size[1] - dst_subset = index_to_mask(dst_subset, size=dst_size) + dst_node_mask = index_to_mask(dst_subset, size=dst_size) + else: + dst_size = dst_subset.size(0) + dst_node_mask = dst_subset - # node_mask = subset - edge_mask = src_subset[edge_index[0]] & dst_subset[edge_index[1]] + edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] edge_attr = edge_attr[edge_mask] if edge_attr is not None else None if relabel_nodes: - node_idx_i = edge_index.new_zeros(src_subset.size(0)) - node_idx_j = edge_index.new_zeros(dst_subset.size(0)) - node_idx_i[src_subset] = torch.arange(int(src_subset.sum()), + node_idx_i = edge_index.new_zeros(src_node_mask.size(0)) + node_idx_j = edge_index.new_zeros(dst_node_mask.size(0)) + node_idx_i[src_subset] = torch.arange(int(src_node_mask.sum()), device=node_idx_i.device) - node_idx_j[dst_subset] = torch.arange(int(dst_subset.sum()), + node_idx_j[dst_subset] = torch.arange(int(dst_node_mask.sum()), device=node_idx_j.device) edge_index = torch.stack([ node_idx_i[edge_index[0]],