|
5 | 5 | import torch.nn.functional as F
|
6 | 6 |
|
7 | 7 | from torch_geometric.data import (
|
| 8 | + HeteroData, |
8 | 9 | LightningDataset,
|
9 | 10 | LightningLinkData,
|
10 | 11 | LightningNodeData,
|
|
18 | 19 | LightningModule = torch.nn.Module
|
19 | 20 |
|
20 | 21 |
|
| 22 | +def get_edge_index(num_src_nodes, num_dst_nodes, num_edges): |
| 23 | + row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long) |
| 24 | + col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long) |
| 25 | + return torch.stack([row, col], dim=0) |
| 26 | + |
| 27 | + |
21 | 28 | class LinearGraphModule(LightningModule):
|
22 | 29 | def __init__(self, in_channels, hidden_channels, out_channels):
|
23 | 30 | super().__init__()
|
@@ -273,10 +280,18 @@ def test_lightning_hetero_node_data(get_dataset):
|
273 | 280 | @withCUDA
|
274 | 281 | @onlyFullTest
|
275 | 282 | @withPackage('pytorch_lightning')
|
276 |
| -def test_lightning_hetero_link_data(get_dataset): |
277 |
| - # TODO: Add more datasets. |
278 |
| - dataset = get_dataset(name='DBLP') |
279 |
| - data = dataset[0] |
| 283 | +def test_lightning_hetero_link_data(): |
| 284 | + torch.manual_seed(12345) |
| 285 | + |
| 286 | + data = HeteroData() |
| 287 | + |
| 288 | + data['paper'].x = torch.arange(10) |
| 289 | + data['author'].x = torch.arange(10) |
| 290 | + data['term'].x = torch.arange(10) |
| 291 | + |
| 292 | + data['paper', 'author'].edge_index = get_edge_index(10, 10, 10) |
| 293 | + data['author', 'paper'].edge_index = get_edge_index(10, 10, 10) |
| 294 | + data['paper', 'term'].edge_index = get_edge_index(10, 10, 10) |
280 | 295 |
|
281 | 296 | datamodule = LightningLinkData(
|
282 | 297 | data,
|
|
0 commit comments