Skip to content

Commit

Permalink
Fix subgraph on unordered inputs (#7187)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Apr 17, 2023
1 parent 8999e50 commit 3d4836b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Fixed `subgraph` on unordered inputs ([#7187](https://github.com/pyg-team/pytorch_geometric/pull/7187))
- Allow missing node types in `HeteroDictLinear` ([#7185](https://github.com/pyg-team/pytorch_geometric/pull/7185))
- Optimized `from_networkx` memory footprint by reducing unnecessary copies ([#7119](https://github.com/pyg-team/pytorch_geometric/pull/7119))
- Added an optional `batch_size` argument to `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` ([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135))
Expand Down
34 changes: 20 additions & 14 deletions torch_geometric/utils/subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,21 @@ def subgraph(
if isinstance(subset, (list, tuple)):
subset = torch.tensor(subset, dtype=torch.long, device=device)

if subset.dtype == torch.bool or subset.dtype == torch.uint8:
num_nodes = subset.size(0)
else:
if subset.dtype != torch.bool:
num_nodes = maybe_num_nodes(edge_index, num_nodes)
subset = index_to_mask(subset, size=num_nodes)
node_mask = index_to_mask(subset, size=num_nodes)
else:
num_nodes = subset.size(0)
node_mask = subset

node_mask = subset
edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
edge_index = edge_index[:, edge_mask]
edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

if relabel_nodes:
node_idx = torch.zeros(node_mask.size(0), dtype=torch.long,
device=device)
node_idx[subset] = torch.arange(subset.sum().item(), device=device)
node_idx[subset] = torch.arange(node_mask.sum().item(), device=device)
edge_index = node_idx[edge_index]

if return_edge_mask:
Expand Down Expand Up @@ -168,22 +168,28 @@ def bipartite_subgraph(

if src_subset.dtype != torch.bool:
src_size = int(edge_index[0].max()) + 1 if size is None else size[0]
src_subset = index_to_mask(src_subset, size=src_size)
src_node_mask = index_to_mask(src_subset, size=src_size)
else:
src_size = src_subset.size(0)
src_node_mask = src_subset

if dst_subset.dtype != torch.bool:
dst_size = int(edge_index[1].max()) + 1 if size is None else size[1]
dst_subset = index_to_mask(dst_subset, size=dst_size)
dst_node_mask = index_to_mask(dst_subset, size=dst_size)
else:
dst_size = dst_subset.size(0)
dst_node_mask = dst_subset

# node_mask = subset
edge_mask = src_subset[edge_index[0]] & dst_subset[edge_index[1]]
edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]]
edge_index = edge_index[:, edge_mask]
edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

if relabel_nodes:
node_idx_i = edge_index.new_zeros(src_subset.size(0))
node_idx_j = edge_index.new_zeros(dst_subset.size(0))
node_idx_i[src_subset] = torch.arange(int(src_subset.sum()),
node_idx_i = edge_index.new_zeros(src_node_mask.size(0))
node_idx_j = edge_index.new_zeros(dst_node_mask.size(0))
node_idx_i[src_subset] = torch.arange(int(src_node_mask.sum()),
device=node_idx_i.device)
node_idx_j[dst_subset] = torch.arange(int(dst_subset.sum()),
node_idx_j[dst_subset] = torch.arange(int(dst_node_mask.sum()),
device=node_idx_j.device)
edge_index = torch.stack([
node_idx_i[edge_index[0]],
Expand Down

0 comments on commit 3d4836b

Please sign in to comment.