adding CCC #227
Conversation
@ninamiolane I think I addressed that issue, but there still seem to be differences between my local env and this one. Tests pass on my local env.
Tests need to pass on GitHub. This seems to be a problem with mypy: you can click on "Details" next to the failing tests to see the exact logs.
"""Higher-Order Attentional NN Layer for Mesh Classification."""

from typing import List
Please don't use typing.List anymore, it's deprecated. Use the built-in list.
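For context: since Python 3.9 (PEP 585) the built-in collection types accept subscripts directly, so the `typing` aliases are redundant. A minimal sketch, with a hypothetical `scale_all` helper just for illustration:

```python
# Since Python 3.9 (PEP 585), built-in generics replace typing.List/Dict/Tuple.
# No import from typing is needed for these annotations.

def scale_all(values: list[float], factor: float) -> list[float]:
    """Scale every value by a constant factor."""
    return [v * factor for v in values]

print(scale_all([1.0, 2.0], 3.0))  # [3.0, 6.0]
```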
target_out_channels: int,
negative_slope: float = 0.2,
softmax: bool = False,
update_func: str = None,
The value None is invalid for a parameter of type str.
Suggested change:
- update_func: str = None,
+ update_func: str | None = None,
def update(
    self, message_on_source: torch.Tensor, message_on_target: torch.Tensor
) -> tuple[torch.Tensor]:
You are returning two tensors. Remember that tuple is special among collections in that you annotate each individual element of the tuple. See https://docs.python.org/3/library/typing.html#annotating-tuples
Suggested change:
- ) -> tuple[torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor]:
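To illustrate the distinction: `tuple[T]` means a one-element tuple, not "a tuple of Ts". Each position is annotated separately, and homogeneous variable-length tuples use an ellipsis. A sketch with illustrative names:

```python
# tuple[int] would mean a 1-tuple; a pair needs both positions spelled out.
def min_max(xs: list[int]) -> tuple[int, int]:
    """Return the smallest and largest element as a fixed-length pair."""
    return min(xs), max(xs)

# Variable-length homogeneous tuples use an ellipsis instead.
Row = tuple[float, ...]

print(min_max([3, 1, 4]))  # (1, 4)
```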
def attention(
    self, s_message: torch.Tensor, t_message: torch.Tensor
) -> tuple[torch.Tensor]:
Suggested change:
- ) -> tuple[torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor]:
neighborhood_t_to_s_att = torch.sparse_coo_tensor(
    indices=neighborhood_t_to_s.indices(),
    values=t_to_s_attention.values() * neighborhood_t_to_s.values(),
According to the types, t_to_s_attention and neighborhood_t_to_s are dense tensors, which do not have a values() method.
Totally right! I suggested changing the types to torch.sparse.FloatTensor in the attention method.
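For context, `indices()` and `values()` exist only on sparse COO tensors, which is why the annotations matter here. A minimal sketch of the element-wise rescaling the snippet performs, assuming both tensors share the same sparsity pattern (this is a toy reconstruction, not the PR's code):

```python
import torch

# Two sparse COO tensors with the same sparsity pattern: entries (0, 1) and (1, 0).
indices = torch.tensor([[0, 1], [1, 0]])
neighborhood = torch.sparse_coo_tensor(
    indices, torch.tensor([1.0, 1.0]), (2, 2)
).coalesce()
attention = torch.sparse_coo_tensor(
    indices, torch.tensor([0.5, 0.25]), (2, 2)
).coalesce()

# .values()/.indices() are only defined for (coalesced) sparse tensors;
# calling them on a dense tensor raises, which is what mypy caught.
scaled = torch.sparse_coo_tensor(
    indices=neighborhood.indices(),
    values=attention.values() * neighborhood.values(),
    size=neighborhood.shape,
)
print(scaled.to_dense())
```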
negative_slope: float = 0.2,
softmax: bool = False,
m_hop: int = 1,
update_func: str = None,
Suggested change:
- update_func: str = None,
+ update_func: str | None = None,
elif self.update_func == "relu":
    return torch.nn.functional.relu(message)
elif self.update_func == "tanh":
    return torch.nn.functional.tanh(message)
What if none of the cases match?
+1. Also, elif is not needed after a return statement; you can use if.
# Create a torch.eye with the device of x_source
result = torch.eye(x_source.shape[0], device=self.get_device()).to_sparse_coo()

neighborhood = [
Please don't overwrite a variable with another type; just use a different name. This is hard to read.
@rballeba this is confusing for me too. I tried to fix it, but I am not sure what you are trying to do here in your implementation. Can you please fix this PR from here? This should be easy for you since you coded this layer originally. Thanks!
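One readability fix along the lines of this comment is to keep the input matrix under one name and accumulate its powers under another, so no variable changes type or meaning mid-function. A hedged sketch of what the m-hop computation appears to be doing (the helper name is hypothetical, not the PR's final code):

```python
import torch

def m_hop_neighborhoods(neighborhood: torch.Tensor, m_hop: int) -> list[torch.Tensor]:
    """Return [A, A^2, ..., A^m_hop] for a sparse COO matrix A."""
    n = neighborhood.shape[0]
    power = torch.eye(n).to_sparse()  # running power, distinct from the input
    powers = []
    for _ in range(m_hop):
        power = torch.sparse.mm(neighborhood, power)
        powers.append(power)
    return powers

# Toy check: A is a 2x2 permutation matrix, so A @ A is the identity.
A = torch.tensor([[0.0, 1.0], [1.0, 0.0]]).to_sparse()
hops = m_hop_neighborhoods(A, 2)
print(hops[1].to_dense())
```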
"""Higher-Order Attentional NN Layer for Mesh Classification."""

from typing import List
Suggested change:
- from typing import List
def update(
    self, message_on_source: torch.Tensor, message_on_target: torch.Tensor
) -> tuple[torch.Tensor]:
Suggested change:
- def update(
-     self, message_on_source: torch.Tensor, message_on_target: torch.Tensor
- ) -> tuple[torch.Tensor]:
+ def update(
+     self, message_on_source: torch.Tensor, message_on_target: torch.Tensor
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
def attention(
    self, s_message: torch.Tensor, t_message: torch.Tensor
) -> tuple[torch.Tensor]:
Suggested change:
- def attention(
-     self, s_message: torch.Tensor, t_message: torch.Tensor
- ) -> tuple[torch.Tensor]:
+ def attention(
+     self, s_message: torch.Tensor, t_message: torch.Tensor
+ ) -> tuple[torch.sparse.FloatTensor, torch.sparse.FloatTensor]:
def forward(
    self, x_source: torch.Tensor, x_target: torch.Tensor, neighborhood: torch.Tensor
) -> tuple[torch.Tensor]:
Suggested change:
- def forward(
-     self, x_source: torch.Tensor, x_target: torch.Tensor, neighborhood: torch.Tensor
- ) -> tuple[torch.Tensor]:
+ def forward(
+     self, x_source: torch.Tensor, x_target: torch.Tensor, neighborhood: torch.Tensor
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
elif self.update_func == "relu":
    return torch.nn.functional.relu(message)
elif self.update_func == "tanh":
    return torch.nn.functional.tanh(message)
Suggested change:
- elif self.update_func == "relu":
-     return torch.nn.functional.relu(message)
- elif self.update_func == "tanh":
-     return torch.nn.functional.tanh(message)
+ if self.update_func == "relu":
+     return torch.nn.functional.relu(message)
+ if self.update_func == "tanh":
+     return torch.nn.functional.tanh(message)
+ raise RuntimeError(
+     "Update function not recognized. Should be either sigmoid, relu or tanh."
+ )
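The shape of that fix, sketched here without torch so the pattern stands alone: early-return guards make elif unnecessary, and an unconditional raise at the end means an unknown name can never fall through silently. The `apply_update` helper is illustrative, not the PR's method:

```python
import math

def apply_update(update_func: str, x: float) -> float:
    # Each branch returns immediately, so elif is not needed.
    if update_func == "sigmoid":
        return 1.0 / (1.0 + math.exp(-x))
    if update_func == "relu":
        return max(0.0, x)
    if update_func == "tanh":
        return math.tanh(x)
    # Reached only when no case matched: fail loudly instead of returning None.
    raise RuntimeError(
        "Update function not recognized. Should be either sigmoid, relu or tanh."
    )

print(apply_update("relu", -3.0))  # 0.0
```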
def attention(
    self, message: torch.Tensor, A_p: torch.Tensor, a_p: torch.Tensor
) -> torch.Tensor:
Suggested change:
- ) -> torch.Tensor:
+ ) -> torch.sparse.FloatTensor:
def forward(
    self, x_source: torch.Tensor, neighborhood: torch.Tensor
) -> torch.Tensor:
Suggested change:
- ) -> torch.Tensor:
+ ) -> torch.FloatTensor:
result = torch.eye(x_source.shape[0], device=self.get_device()).to_sparse_coo()

neighborhood = [
    result := torch.sparse.mm(neighborhood, result) for _ in range(self.m_hop)
]

att = [
    self.attention(m_p, A_p, a_p)
    for m_p, A_p, a_p in zip(message, neighborhood, self.att_weight)
]

def sparse_hadamard(A_p, att_p):
    return torch.sparse_coo_tensor(
        indices=A_p.indices(),
        values=att_p.values() * A_p.values(),
        size=A_p.shape,
        device=self.get_device(),
    )

neighborhood = [
    sparse_hadamard(A_p, att_p) for A_p, att_p in zip(neighborhood, att)
]
message = [torch.mm(n_p, m_p) for n_p, m_p in zip(neighborhood, message)]
target_out_channels: int,
negative_slope: float = 0.2,
softmax: bool = False,
update_func: str = None,
Suggested change:
- update_func: str = None,
+ update_func: str | None = None,
@ManuelLecha I cannot approve this PR before all tests pass.
Codecov Report

@@ Coverage Diff @@
##             main     #227      +/-   ##
==========================================
+ Coverage   96.36%   96.54%   +0.18%
==========================================
  Files          55       58       +3
  Lines        2036     2230     +194
==========================================
+ Hits         1962     2153     +191
- Misses         74       77       +3

☔ View full report in Codecov by Sentry.
adding CCC tutorial