From ad82018fe23e5f7b4e54bce03033ef8a0af3658a Mon Sep 17 00:00:00 2001 From: migalkin Date: Fri, 20 May 2022 15:55:01 -0400 Subject: [PATCH 1/3] compute ranks fix --- examples/rgcn_link_pred.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/rgcn_link_pred.py b/examples/rgcn_link_pred.py index 6c988ed931f1..df1ab9610f62 100644 --- a/examples/rgcn_link_pred.py +++ b/examples/rgcn_link_pred.py @@ -112,6 +112,14 @@ def test(): return valid_mrr, test_mrr +@torch.no_grad() +def compute_rank(ranks): + # fair ranking prediction as the average of optimistic and pessimistic ranking + true = ranks[0] + optimistic = (ranks > true).sum() + 1 + pessimistic = (ranks >= true).sum() + return (optimistic + pessimistic).float() * 0.5 + @torch.no_grad() def compute_mrr(z, edge_index, edge_type): @@ -135,9 +143,8 @@ def compute_mrr(z, edge_index, edge_type): eval_edge_type = torch.full_like(tail, fill_value=rel) out = model.decode(z, eval_edge_index, eval_edge_type) - perm = out.argsort(descending=True) - rank = int((perm == 0).nonzero(as_tuple=False).view(-1)[0]) - ranks.append(rank + 1) + rank = compute_rank(out) + ranks.append(rank) # Try all nodes as heads, but delete true triplets: head_mask = torch.ones(data.num_nodes, dtype=torch.bool) @@ -155,9 +162,8 @@ def compute_mrr(z, edge_index, edge_type): eval_edge_type = torch.full_like(head, fill_value=rel) out = model.decode(z, eval_edge_index, eval_edge_type) - perm = out.argsort(descending=True) - rank = int((perm == 0).nonzero(as_tuple=False).view(-1)[0]) - ranks.append(rank + 1) + rank = compute_rank(out) + ranks.append(rank) return (1. / torch.tensor(ranks, dtype=torch.float)).mean() From b5293951aca23950c0754c3da2f6342001420440 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 May 2022 20:11:35 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/rgcn_link_pred.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/rgcn_link_pred.py b/examples/rgcn_link_pred.py index df1ab9610f62..43c2332f8008 100644 --- a/examples/rgcn_link_pred.py +++ b/examples/rgcn_link_pred.py @@ -112,6 +112,7 @@ def test(): return valid_mrr, test_mrr + @torch.no_grad() def compute_rank(ranks): # fair ranking prediction as the average of optimistic and pessimistic ranking From c9825d34ba9ca7960120352f74553d704e420103 Mon Sep 17 00:00:00 2001 From: migalkin Date: Fri, 20 May 2022 16:25:13 -0400 Subject: [PATCH 3/3] pleasing PEP8 --- CHANGELOG.md | 1 + examples/rgcn_link_pred.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b31ee021aaa..01e0f6184a0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- Fixed the ranking protocol bug in the RGCN link prediction example ([#4688](https://github.com/pyg-team/pytorch_geometric/pull/4688)) - Math support in Markdown ([#4683](https://github.com/pyg-team/pytorch_geometric/pull/4683)) - Allow for `setter` properties in `Data` ([#4682](https://github.com/pyg-team/pytorch_geometric/pull/4682), [#4686](https://github.com/pyg-team/pytorch_geometric/pull/4686)) - Allow for optional `edge_weight` in `GCN2Conv` ([#4670](https://github.com/pyg-team/pytorch_geometric/pull/4670)) diff --git a/examples/rgcn_link_pred.py b/examples/rgcn_link_pred.py index df1ab9610f62..3786007b5e24 100644 --- a/examples/rgcn_link_pred.py +++ b/examples/rgcn_link_pred.py @@ -114,7 +114,8 @@ def test(): @torch.no_grad() def compute_rank(ranks): - # fair ranking prediction as the average of optimistic and pessimistic ranking + # fair ranking prediction as the average + # of optimistic and pessimistic ranking true = ranks[0] optimistic = (ranks > true).sum() + 1 pessimistic = (ranks >= true).sum()