Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🩹 👾 Fix ERMLP functional form #444

Merged
merged 10 commits into from
Jun 10, 2021
Merged

🩹 👾 Fix ERMLP functional form #444

merged 10 commits into from
Jun 10, 2021

Conversation

mberr
Copy link
Member

@mberr mberr commented May 12, 2021

This PR fixes an issue with the ER-MLP functional form.

I encountered the issue trying to run inference on a single element batch

result = pipeline(
    interaction="ermlp",
    interaction_kwargs=dict(embedding_dim=32, hidden_dim=32),
    dimensions=dict(d=32),
    dataset="nations",
    loss="bcewithlogits",
    training_kwargs=dict(num_epochs=1),
)
result.model.predict_t(result.training.mapped_triples[:1, 1:])

@cthoyt
Copy link
Member

cthoyt commented May 12, 2021

Can you encode the failing functionality in a test?

@mberr
Copy link
Member Author

mberr commented May 12, 2021

Can you encode the failing functionality in a test?

The error in the implementation was that SizeInformation extracted the wrong size information.

In nn.functional we always use the "more canonical" form (b, nh, nr, nt, *), whereas the score_* methods of modules only use the canonical form (b, n*, *).

For new-style interaction functions, we only test the score_* methods of the interaction module. The ERModel does not use them but always directly goes to the generic score one (without suffix).

I can try to come up with a test, but right now I just stumbled across this error and pushed my local fix here without having time for a deeper investigation of the root cause.

@cthoyt cthoyt added the bug Something isn't working label May 20, 2021
@mberr
Copy link
Member Author

mberr commented Jun 10, 2021

A smaller snippet for reproducing the error without any training:

from pykeen.models.resolve import make_model_cls
from pykeen.triples.generation import generate_triples_factory

triples_factory = generate_triples_factory()
model_cls = make_model_cls(
    interaction="ermlp", 
    dimensions=dict(d=3), 
    interaction_kwargs=dict(
        embedding_dim=3, 
        hidden_dim=5,
    ),
)
model = model_cls(triples_factory=triples_factory)
model.predict_t(triples_factory.mapped_triples[:1, :-1])

@mberr
Copy link
Member Author

mberr commented Jun 10, 2021

An even smaller one

import torch
from pykeen.nn.modules import ERMLPInteraction

num = 5
dim = 3
batch_size = 1
interaction = ERMLPInteraction(embedding_dim=dim, hidden_dim=1)
interaction.score_t(h=torch.rand(batch_size, dim), r=torch.rand(batch_size, dim), all_entities=torch.rand(num, dim))

@mberr
Copy link
Member Author

mberr commented Jun 10, 2021

This is strange, since we have

class ERMLPTests(cases.InteractionTestCase):

and

pykeen/tests/cases.py

Lines 445 to 453 in 9b7adc6

def test_score_t(self):
"""Test score_t."""
h, r, t = self._get_hrt(
(self.batch_size,),
(self.batch_size,),
(self.num_entities,),
)
scores = self.instance.score_t(h=h, r=r, all_entities=t)
self._check_scores(scores=scores, exp_shape=(self.batch_size, self.num_entities))

which essentially should do the same.

@mberr
Copy link
Member Author

mberr commented Jun 10, 2021

batch_size=1

seems to be what causes the problem.

@mberr
Copy link
Member Author

mberr commented Jun 10, 2021

@cthoyt 53ed846 this catches the old problem. Maybe we want this for the other score_* functions, too?

@cthoyt
Copy link
Member

cthoyt commented Jun 10, 2021

@mberr yes let’s implement this check for all models/score_*

@mberr
Copy link
Member Author

mberr commented Jun 10, 2021

@cthoyt there seems to be a new version of mypy, since there are now some mypy errors in non-modified code parts, cf. https://github.com/pykeen/pykeen/pull/444/checks?check_run_id=2793677666#step:9:12

@cthoyt cthoyt changed the title Fix ERMLP functional form 🩹 👾 Fix ERMLP functional form Jun 10, 2021
Copy link
Member

@cthoyt cthoyt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for me. Squash+merge when ready

@mberr mberr merged commit 0d396bb into master Jun 10, 2021
@mberr mberr deleted the fix-ermlp-functional-form branch June 10, 2021 19:08
@cthoyt
Copy link
Member

cthoyt commented Jun 10, 2021

🦜 💃 ⏩

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants