Skip to content

Commit

Permalink
[Code Coverage] HeteroConv (#6568)
Browse files Browse the repository at this point in the history
`HeteroConv` code coverage improvement
  • Loading branch information
zechengz authored Feb 2, 2023
1 parent 516220d commit 5868159
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
34 changes: 22 additions & 12 deletions test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,27 @@ def test_hetero_conv(aggr):
data['author'].x = torch.randn(30, 64)
data['paper', 'paper'].edge_index = get_edge_index(50, 50, 200)
data['paper', 'author'].edge_index = get_edge_index(50, 30, 100)
data['paper', 'author'].edge_attr = torch.randn(100, 3)
data['author', 'paper'].edge_index = get_edge_index(30, 50, 100)
data['paper', 'paper'].edge_weight = torch.rand(200)

# Unspecified edge types should be ignored:
data['author', 'author'].edge_index = get_edge_index(30, 30, 100)

conv = HeteroConv(
{
('paper', 'to', 'paper'): GCNConv(-1, 64),
('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv(
(-1, -1), 64, add_self_loops=False),
('paper', 'to', 'paper'):
GCNConv(-1, 64),
('author', 'to', 'paper'):
SAGEConv((-1, -1), 64),
('paper', 'to', 'author'):
GATConv((-1, -1), 64, edge_dim=3, add_self_loops=False),
}, aggr=aggr)

assert len(list(conv.parameters())) > 0
assert str(conv) == 'HeteroConv(num_relations=3)'

out = conv(data.x_dict, data.edge_index_dict,
out = conv(data.x_dict, data.edge_index_dict, data.edge_attr_dict,
edge_weight_dict=data.edge_weight_dict)

assert len(out) == 2
Expand All @@ -56,25 +62,29 @@ def __init__(self, out_channels):
super().__init__(aggr='add')
self.lin = Linear(-1, out_channels)

def forward(self, x, edge_index, pos):
return self.propagate(edge_index, x=x, pos=pos)
def forward(self, x, edge_index, y, z):
return self.propagate(edge_index, x=x, y=y, z=z)

def message(self, x_j, pos_i, pos_j):
return self.lin(torch.cat([x_j, pos_i - pos_j], dim=-1))
def message(self, x_j, y_j, z_j):
return self.lin(torch.cat([x_j, y_j, z_j], dim=-1))


def test_hetero_conv_with_custom_conv():
data = HeteroData()
data['paper'].x = torch.randn(50, 32)
data['paper'].pos = torch.randn(50, 3)
data['paper'].y = torch.randn(50, 3)
data['paper'].z = torch.randn(50, 3)
data['author'].x = torch.randn(30, 64)
data['author'].pos = torch.randn(30, 3)
data['author'].y = torch.randn(30, 3)
data['author'].z = torch.randn(30, 3)
data['paper', 'paper'].edge_index = get_edge_index(50, 50, 200)
data['paper', 'author'].edge_index = get_edge_index(50, 30, 100)
data['author', 'paper'].edge_index = get_edge_index(30, 50, 100)

conv = HeteroConv({key: CustomConv(64) for key in data.edge_types})
out = conv(data.x_dict, data.edge_index_dict, data.pos_dict)
# Test node `args_dict` and `kwargs_dict` with `y_dict` and `z_dict`:
out = conv(data.x_dict, data.edge_index_dict, data.y_dict,
z_dict=data.z_dict)
assert len(out) == 2
assert out['paper'].size() == (50, 64)
assert out['author'].size() == (30, 64)
Expand Down

0 comments on commit 5868159

Please sign in to comment.