From adb2afb9ba09fa7025150fd89143193e8ed64f2e Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 17 Apr 2023 08:41:05 +0000 Subject: [PATCH 1/3] update --- torch_geometric/utils/subgraph.py | 34 ++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/torch_geometric/utils/subgraph.py b/torch_geometric/utils/subgraph.py index b20a645b6163..490e53a6b78f 100644 --- a/torch_geometric/utils/subgraph.py +++ b/torch_geometric/utils/subgraph.py @@ -87,13 +87,13 @@ def subgraph( if isinstance(subset, (list, tuple)): subset = torch.tensor(subset, dtype=torch.long, device=device) - if subset.dtype == torch.bool or subset.dtype == torch.uint8: - num_nodes = subset.size(0) - else: + if subset.dtype != torch.bool: num_nodes = maybe_num_nodes(edge_index, num_nodes) - subset = index_to_mask(subset, size=num_nodes) + node_mask = index_to_mask(subset, size=num_nodes) + else: + num_nodes = subset.size(0) + node_mask = subset - node_mask = subset edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] edge_attr = edge_attr[edge_mask] if edge_attr is not None else None @@ -101,7 +101,7 @@ def subgraph( if relabel_nodes: node_idx = torch.zeros(node_mask.size(0), dtype=torch.long, device=device) - node_idx[subset] = torch.arange(subset.sum().item(), device=device) + node_idx[subset] = torch.arange(node_mask.sum().item(), device=device) edge_index = node_idx[edge_index] if return_edge_mask: @@ -168,22 +168,28 @@ def bipartite_subgraph( if src_subset.dtype != torch.bool: src_size = int(edge_index[0].max()) + 1 if size is None else size[0] - src_subset = index_to_mask(src_subset, size=src_size) + src_node_mask = index_to_mask(src_subset, size=src_size) + else: + src_size = src_subset.size(0) + src_node_mask = src_subset + if dst_subset.dtype != torch.bool: dst_size = int(edge_index[1].max()) + 1 if size is None else size[1] - dst_subset = index_to_mask(dst_subset, size=dst_size) + dst_node_mask = index_to_mask(dst_subset, size=dst_size) + else: + dst_size = dst_subset.size(0) + dst_node_mask = dst_subset - # node_mask = subset - edge_mask = src_subset[edge_index[0]] & dst_subset[edge_index[1]] + edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]] edge_index = edge_index[:, edge_mask] edge_attr = edge_attr[edge_mask] if edge_attr is not None else None if relabel_nodes: - node_idx_i = edge_index.new_zeros(src_subset.size(0)) - node_idx_j = edge_index.new_zeros(dst_subset.size(0)) - node_idx_i[src_subset] = torch.arange(int(src_subset.sum()), + node_idx_i = edge_index.new_zeros(src_node_mask.size(0)) + node_idx_j = edge_index.new_zeros(dst_node_mask.size(0)) + node_idx_i[src_subset] = torch.arange(int(src_node_mask.sum()), device=node_idx_i.device) - node_idx_j[dst_subset] = torch.arange(int(dst_subset.sum()), + node_idx_j[dst_subset] = torch.arange(int(dst_node_mask.sum()), device=node_idx_j.device) edge_index = torch.stack([ node_idx_i[edge_index[0]], From 8fecd37c9dcaa836e5b117e3bc5da970c406a7d5 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 17 Apr 2023 08:43:01 +0000 Subject: [PATCH 2/3] update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65ee963b8a62..d97a830c58ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Fixed subgraph on unordered inputs ([#7187](https://github.com/pyg-team/pytorch_geometric/pull/7187)) - Allow missing node types in `HeteroDictLinear` ([#7185](https://github.com/pyg-team/pytorch_geometric/pull/7185)) - Optimized `from_networkx` memory footprint by reducing unnecessary copies ([#7119](https://github.com/pyg-team/pytorch_geometric/pull/7119)) - Added an optional `batch_size` argument to `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135)) From d6eba9480d75de61cedbcbcd3f04bc5ce211620c Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 17 Apr 2023 08:43:44 +0000 Subject: [PATCH 3/3] update --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d97a830c58ff..1316c353fcfe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- Fixed subgraph on unordered inputs ([#7187](https://github.com/pyg-team/pytorch_geometric/pull/7187)) +- Fixed `subgraph` on unordered inputs ([#7187](https://github.com/pyg-team/pytorch_geometric/pull/7187)) - Allow missing node types in `HeteroDictLinear` ([#7185](https://github.com/pyg-team/pytorch_geometric/pull/7185)) - Optimized `from_networkx` memory footprint by reducing unnecessary copies ([#7119](https://github.com/pyg-team/pytorch_geometric/pull/7119)) - Added an optional `batch_size` argument to `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135))