From 6b651a3d7ae93e916099d34753cdb28868bb7748 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Thu, 21 Apr 2022 08:07:19 +0800 Subject: [PATCH 1/7] test label types --- test/loader/test_link_neighbor_loader.py | 31 +++++++++++++++++++ .../loader/link_neighbor_loader.py | 5 +-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/test/loader/test_link_neighbor_loader.py b/test/loader/test_link_neighbor_loader.py index 76b1a7099975..14d9c12d58e1 100644 --- a/test/loader/test_link_neighbor_loader.py +++ b/test/loader/test_link_neighbor_loader.py @@ -155,3 +155,34 @@ def test_heterogeneous_link_neighbor_loader_loop(directed): edge_label_index = batch['paper', 'paper'].edge_label_index edge_label_index = unique_edge_pairs(edge_label_index) assert len(edge_index | edge_label_index) == len(edge_index) + + +def test_link_neighbour_loader_labels(): + + torch.manual_seed(12345) + + edge_index = get_edge_index(100, 100, 500) + data = Data(edge_index=edge_index, x=torch.arange(100)) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + batch_size=10, + neg_sampling_ratio=1, + ) + + for batch in loader: + assert all(batch.edge_label[:10] == 1) + assert all(batch.edge_label[10:] == 0) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + batch_size=10, + edge_label=torch.ones(500).type(torch.long), + neg_sampling_ratio=1, + ) + + for batch in loader: + assert all(batch.edge_label[:10] == 2) + assert all(batch.edge_label[10:] == 0) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 3cb2a6645bea..319727738aac 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -31,7 +31,8 @@ def _create_label(self, edge_label_index, edge_label): return edge_label_index, edge_label if edge_label is None: - edge_label = torch.ones(num_pos_edges, device=device) + edge_label = torch.ones(num_pos_edges, + device=device).type(torch.long) else: assert edge_label.dtype == torch.long edge_label = edge_label + 1 @@ -170,7 +171,7 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): train_mask=[1368], val_mask=[1368], test_mask=[1368], edge_label_index=[2, 128], edge_label=[128]) - The rest of the functionality mirros that of + The rest of the functionality mirrors that of :class:`~torch_geometric.loader.NeighborLoader`, including support for heterogenous graphs. From 06a7301e0d89519b29c21ebc2c7308920a10dafa Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Thu, 21 Apr 2022 22:13:38 +0800 Subject: [PATCH 2/7] update docstring --- torch_geometric/loader/link_neighbor_loader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 319727738aac..b5d293fc0593 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -31,8 +31,7 @@ def _create_label(self, edge_label_index, edge_label): return edge_label_index, edge_label if edge_label is None: - edge_label = torch.ones(num_pos_edges, - device=device).type(torch.long) + edge_label = torch.ones(num_pos_edges, device=device) else: assert edge_label.dtype == torch.long edge_label = edge_label + 1 @@ -214,7 +213,10 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): :obj:`0` to :obj:`num_classes - 1`. After negative sampling, label :obj:`0` represents negative edges, and labels :obj:`1` to :obj:`num_classes` represent the labels of - positive edges. (default: :obj:`0.0`) + positive edges. (default: :obj:`0.0`). + Note that returned labels are :obj:`torch.float` when it is binary + classification (to facilitate easy use with cross entropy loss) + and :obj:`torch.long` when is multiclass. **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. From 6bbf4802a70df7df6efe40c1fb98c6266eb4e2b2 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Thu, 21 Apr 2022 22:16:49 +0800 Subject: [PATCH 3/7] update docstring --- torch_geometric/loader/link_neighbor_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index b5d293fc0593..0e47aa3e9799 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -215,8 +215,8 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): and labels :obj:`1` to :obj:`num_classes` represent the labels of positive edges. (default: :obj:`0.0`). Note that returned labels are :obj:`torch.float` when it is binary - classification (to facilitate easy use with cross entropy loss) - and :obj:`torch.long` when is multiclass. + classification (to facilitate easy use with :obj:`BCE_loss`) + and :obj:`torch.long` when is a multiclass classification problem. **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. From b6f70fad2238c4d46e0bd5b3f3ec5fb061c9497e Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Fri, 22 Apr 2022 07:57:23 +0800 Subject: [PATCH 4/7] Update torch_geometric/loader/link_neighbor_loader.py Co-authored-by: Matthias Fey --- torch_geometric/loader/link_neighbor_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 2e948f67cb1e..23c984a9c55c 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -216,7 +216,7 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): positive edges. (default: :obj:`0.0`). Note that returned labels are :obj:`torch.float` when it is binary classification (to facilitate easy use with :obj:`BCE_loss`) - and :obj:`torch.long` when is a multiclass classification problem. + and :obj:`torch.long` when it is a multi-class classification problem. **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. From e50001545d90b2ced32d9dddda10d170ad978f13 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Fri, 22 Apr 2022 07:57:43 +0800 Subject: [PATCH 5/7] Update torch_geometric/loader/link_neighbor_loader.py Co-authored-by: Matthias Fey --- torch_geometric/loader/link_neighbor_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 23c984a9c55c..f317a0dff526 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -215,7 +215,7 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): and labels :obj:`1` to :obj:`num_classes` represent the labels of positive edges. (default: :obj:`0.0`). Note that returned labels are :obj:`torch.float` when it is binary - classification (to facilitate easy use with :obj:`BCE_loss`) + classification (to facilitate ease-of-use with :meth:`torch.binary_cross_entropy_loss`) and :obj:`torch.long` when it is a multi-class classification problem. **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, From 39e522116fafa58d6b47489ea6fbc374c532405c Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Fri, 22 Apr 2022 08:21:12 +0800 Subject: [PATCH 6/7] line lengtH --- torch_geometric/loader/link_neighbor_loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index f317a0dff526..ad5c58139d4d 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -215,8 +215,9 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): and labels :obj:`1` to :obj:`num_classes` represent the labels of positive edges. (default: :obj:`0.0`). Note that returned labels are :obj:`torch.float` when it is binary - classification (to facilitate ease-of-use with :meth:`torch.binary_cross_entropy_loss`) - and :obj:`torch.long` when it is a multi-class classification problem. + classification (to facilitate ease-of-use with + :meth:`torch.binary_cross_entropy_loss`) and :obj:`torch.long` + when it is a multi-class classification problem. **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. From 7a676854d48414367e53b998307037aa426d5f4e Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 22 Apr 2022 06:49:58 +0200 Subject: [PATCH 7/7] typo --- test/loader/test_link_neighbor_loader.py | 19 ++++++++++--------- .../loader/link_neighbor_loader.py | 11 ++++++----- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/test/loader/test_link_neighbor_loader.py b/test/loader/test_link_neighbor_loader.py index 14d9c12d58e1..f2dd71f415f0 100644 --- a/test/loader/test_link_neighbor_loader.py +++ b/test/loader/test_link_neighbor_loader.py @@ -157,8 +157,7 @@ def test_heterogeneous_link_neighbor_loader_loop(directed): assert len(edge_index | edge_label_index) == len(edge_index) -def test_link_neighbour_loader_labels(): - +def test_link_neighbor_loader_edge_label(): torch.manual_seed(12345) edge_index = get_edge_index(100, 100, 500) @@ -168,21 +167,23 @@ def test_link_neighbour_loader_labels(): data, num_neighbors=[-1] * 2, batch_size=10, - neg_sampling_ratio=1, + neg_sampling_ratio=1.0, ) for batch in loader: - assert all(batch.edge_label[:10] == 1) - assert all(batch.edge_label[10:] == 0) + assert batch.edge_label.dtype == torch.float + assert torch.all(batch.edge_label[:10] == 1.0) + assert torch.all(batch.edge_label[10:] == 0.0) loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, batch_size=10, - edge_label=torch.ones(500).type(torch.long), - neg_sampling_ratio=1, + edge_label=torch.ones(500, dtype=torch.long), + neg_sampling_ratio=1.0, ) for batch in loader: - assert all(batch.edge_label[:10] == 2) - assert all(batch.edge_label[10:] == 0) + assert batch.edge_label.dtype == torch.long + assert torch.all(batch.edge_label[:10] == 2) + assert torch.all(batch.edge_label[10:] == 0) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index ad5c58139d4d..08390100e2f0 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -213,11 +213,12 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): :obj:`0` to :obj:`num_classes - 1`. After negative sampling, label :obj:`0` represents negative edges, and labels :obj:`1` to :obj:`num_classes` represent the labels of - positive edges. (default: :obj:`0.0`). - Note that returned labels are :obj:`torch.float` when it is binary - classification (to facilitate ease-of-use with - :meth:`torch.binary_cross_entropy_loss`) and :obj:`torch.long` - when it is a multi-class classification problem. + positive edges. + Note that returned labels are of type :obj:`torch.float` for binary + classification (to facilitate the ease-of-use of + :meth:`F.binary_cross_entropy`) and of type + :obj:`torch.long` for multi-class classification (to facilitate the + ease-of-use of :meth:`F.cross_entropy`). (default: :obj:`0.0`). **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.