-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add MulAggregation and MultiAggregation * Fix import issue * Support torch_geometric.nn.aggr package, note: jit errors to fix * Add tests for MulAggregation, MultiAggregation, aggregation_resolver and message_passing interface * Formatting * Fix __repr for gen aggrs * Move resolver * Fix test for MulAggregation * Add test for new mp interface * Add test for MultiAggregation * Minor fix * Add warming for MulAggregation with 'ptr' * Resolve aggr to Aggregation module, remove aggrs logic * changelog * Fix mul aggregation * update * update * update * update * reset Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
- Loading branch information
Showing
8 changed files
with
144 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
|
||
from torch_geometric.nn import MultiAggregation | ||
|
||
|
||
def test_multi_aggr(): | ||
x = torch.randn(6, 16) | ||
index = torch.tensor([0, 0, 1, 1, 1, 2]) | ||
ptr = torch.tensor([0, 2, 5, 6]) | ||
|
||
aggrs = ['mean', 'sum', 'max'] | ||
aggr = MultiAggregation(aggrs) | ||
assert str(aggr) == ('MultiAggregation([\n' | ||
' MeanAggregation(),\n' | ||
' SumAggregation(),\n' | ||
' MaxAggregation()\n' | ||
'])') | ||
|
||
out = aggr(x, index) | ||
assert out.size() == (3, len(aggrs) * x.size(1)) | ||
assert torch.allclose(out, aggr(x, ptr=ptr)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torch_geometric.nn.aggr import Aggregation | ||
from torch_geometric.nn.resolver import aggregation_resolver | ||
|
||
|
||
class MultiAggregation(Aggregation): | ||
def __init__(self, aggrs: List[Union[Aggregation, str]]): | ||
super().__init__() | ||
|
||
if not isinstance(aggrs, (list, tuple)): | ||
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should " | ||
f"be a list or tuple (got '{type(aggrs)}')") | ||
|
||
if len(aggrs) == 0: | ||
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should " | ||
f"not be empty") | ||
|
||
self.aggrs = [aggregation_resolver(aggr) for aggr in aggrs] | ||
|
||
def forward(self, x: Tensor, index: Optional[Tensor] = None, *, | ||
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, | ||
dim: int = -2) -> Tensor: | ||
outs = [] | ||
for aggr in self.aggrs: | ||
outs.append(aggr(x, index, ptr=ptr, dim_size=dim_size, dim=dim)) | ||
return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0] | ||
|
||
def __repr__(self) -> str: | ||
args = [f' {aggr}' for aggr in self.aggrs] | ||
return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters