🧨 💎 Add literal interaction and outline for tests #245
Conversation
Trigger CI
Since it needs to auto-append the literal representation, it can't do magic stuff in the middle
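In other words (a hypothetical sketch with made-up shapes, not the PR code): the literal features can only be appended to the entity representation once, up front, before the wrapped interaction scores it:

import torch

entity = torch.randn(3, 50)    # (num_entities, embedding_dim)
literals = torch.randn(3, 10)  # (num_entities, num_of_literals)
# the literal representation is auto-appended here, once, up front -
# nothing gets injected into the middle of the wrapped interaction
combined = torch.cat([entity, literals], dim=-1)  # (3, 60)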
input_dropout: float = 0.0,
):
    linear = nn.Linear(embedding_dim + num_of_literals, embedding_dim)
    dropout = nn.Dropout(input_dropout)
@mberr note the dropout comes after the linear layer for DistMult but before it for ComplEx. Was this correct?
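For reference, the two orderings under discussion look roughly like this (an illustrative sketch, not the PR code):

from torch import nn

embedding_dim, num_of_literals, input_dropout = 50, 10, 0.1

# DistMult-style: project first, apply dropout to the projected output
distmult_combination = nn.Sequential(
    nn.Linear(embedding_dim + num_of_literals, embedding_dim),
    nn.Dropout(input_dropout),
)

# ComplEx-style: apply dropout to the concatenated input, then project
complex_combination = nn.Sequential(
    nn.Dropout(input_dropout),
    nn.Linear(embedding_dim + num_of_literals, embedding_dim),
)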
…tion-modules

# Conflicts:
#	src/pykeen/models/multimodal/complex_literal.py
#	src/pykeen/models/multimodal/distmult_literal.py
embedding_dim=embedding_dim,
initializer=nn.init.xavier_normal_,
dtype=torch.complex64,
# TODO: verify
this needs to be written up a bit differently - no regularization was reported for this model, right? so it's time to remove it
Maybe we should use a mixin here? Maybe something like this:

from typing import Callable

import torch
from torch import nn

class EntityCombinationMixin(Interaction):
    def __init__(self, combination: Callable[..., torch.Tensor], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.combination = combination

    def forward(self, h, r, t):
        h = self.combination(h)
        t = self.combination(t)
        return super().forward(h=h, r=r, t=t)

class LinearProjectionCombination(EntityCombinationMixin):
    def __init__(self, embedding_dim, num_of_literals, *args, **kwargs):
        # build the projection first; register it on self only after
        # nn.Module.__init__ has run (via super().__init__)
        projection = nn.Linear(embedding_dim + num_of_literals, embedding_dim)
        super().__init__(combination=lambda x: projection(torch.cat(x, dim=-1)), *args, **kwargs)
        self.projection = projection

class DistMultLiteral(LinearProjectionCombination, DistMultInteraction):
    pass
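Hypothetical usage of the sketch above (argument handling assumed, not taken from the PR):

interaction = DistMultLiteral(embedding_dim=50, num_of_literals=10)
# h and t arrive as (entity embedding, literal features) pairs; the
# combination concatenates and projects them before DistMult scoring
h = (torch.randn(2, 50), torch.randn(2, 10))
r = torch.randn(2, 50)
t = (torch.randn(2, 50), torch.randn(2, 10))
scores = interaction(h=h, r=r, t=t)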
cool idea, but let's make sure the typing is done right. The
I am not sure whether the
@mberr you're thinking convolution too, right? I'm not the only one thinking that would be cool (but probably ultimately not helpful 😆 )
@mberr right now there's an issue with the ComplEx one
Also remove dead class MultimodalModel
Trigger CI
super().__init__()
self.base = interaction_resolver.make(base, base_kwargs)
self.combination = combination
self.entity_shape = tuple(self.base.entity_shape) + ("e",)
Do we need to make sure that "e" is unused?
yes, there needs to be an extra dimension floating around here that gets added in by the literal model
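To make the shape bookkeeping concrete (a hedged sketch; the actual symbols depend on the wrapped interaction):

# If the wrapped interaction declares entity_shape = ("d",), the literal
# wrapper extends it with one extra symbol for the literal features:
base_entity_shape = ("d",)
entity_shape = tuple(base_entity_shape) + ("e",)  # -> ("d", "e")
# "d": embedding dimension; "e": number of literal features. A name clash
# would only arise if the base interaction already used the symbol "e".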
Added more docs for this in c7b9bc1
    scores_f = self.cls.func(**kwargs).view(-1)
else:
    kwargs = dict(h=h, r=r, t=t)
    scores_f = self.instance(h=h, r=r, t=t)
no need for view(-1) here?
it works without it... sooooooo... you tell me
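For context (a minimal illustration, not the test code): view(-1) just flattens the scores to 1-D, which would only matter if the two sides were compared with different shapes:

import torch

scores = torch.randn(2, 3)  # scores in some (batch, ...) shape
flat = scores.view(-1)      # shape (6,): flattened for comparison
assert torch.equal(flat, scores.reshape(-1))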
besides, looks good to me! Great to have them integrated into testing.
Closes #244. This PR is part of the break-up effort on #107.
This PR implements literal interaction modules, which wrap another interaction module together with a "combination" function. The combination takes the entity representations and the literal representations and produces a combined representation; the implementation is up to the user, the simplest being to concatenate them (see the sketch below).
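A minimal sketch of that simplest combination (names are hypothetical, not the PR's API): concatenate the two representations, then project back down to the embedding dimension.

import torch
from torch import nn

class ConcatProjectionCombination(nn.Module):
    """Toy combination: concatenate entity and literal representations, then project."""

    def __init__(self, embedding_dim: int, num_of_literals: int):
        super().__init__()
        self.projection = nn.Linear(embedding_dim + num_of_literals, embedding_dim)

    def forward(self, entity: torch.Tensor, literals: torch.Tensor) -> torch.Tensor:
        return self.projection(torch.cat([entity, literals], dim=-1))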
There's a bit of goofiness to the typing, since it's hard to express the types dynamically - this makes it difficult to handle any interaction function other than the simple ones that use a single embedding each for head, relation, and tail.
Before Merge
Combination class to take the unconcatenated representations?