
Commit

update
rusty1s committed Nov 23, 2022
1 parent e909b52 · commit 3309f7d
Showing 2 changed files with 19 additions and 9 deletions.
torch_geometric/nn/aggr/basic.py (10 additions, 6 deletions)
@@ -101,9 +101,12 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
         mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')
-        x2 = x.detach() * x.detach() if self.semi_grad else x * x
-        mean_2 = self.reduce(x2, index, ptr, dim_size, dim, 'mean')
-        return mean_2 - mean * mean
+        if self.semi_grad:
+            with torch.no_grad():
+                mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')
+        else:
+            mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')
+        return mean2 - mean * mean
 
 
 class StdAggregation(Aggregation):
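
This hunk swaps the detach()-based ternary in VarAggregation for an explicit torch.no_grad() block: the variance is computed as E[x^2] - E[x]^2, and under semi_grad the E[x^2] term is evaluated outside the autograd graph, so gradients flow only through the mean. A minimal sketch of the same trick in plain PyTorch (group_mean, semi_grad_var, and the toy shapes are illustrative stand-ins, not PyG's API):

import torch

def group_mean(x: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    # Per-group mean via index_add_, a stand-in for PyG's self.reduce(..., 'mean').
    sums = torch.zeros(dim_size, x.size(-1)).index_add_(0, index, x)
    counts = torch.zeros(dim_size).index_add_(0, index, torch.ones(index.numel()))
    return sums / counts.clamp_(min=1).unsqueeze(-1)

def semi_grad_var(x: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    mean = group_mean(x, index, dim_size)
    with torch.no_grad():  # E[x^2] contributes no gradient, as in the diff
        mean2 = group_mean(x * x, index, dim_size)
    return mean2 - mean * mean

x = torch.randn(6, 4, requires_grad=True)
index = torch.tensor([0, 0, 1, 1, 1, 2])
out = semi_grad_var(x, index, dim_size=3)
out.sum().backward()  # gradients reach x only through the mean term

Functionally, detaching the inputs and wrapping the reduction in no_grad stop the same gradients; the no_grad form makes the intent explicit and guards the whole reduction rather than just its inputs.
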
@@ -200,9 +203,10 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
         alpha = x * t
 
         if not self.learn and self.semi_grad:
-            alpha = alpha.detach()
-            alpha = softmax(alpha, index, ptr, dim_size, dim)
-        alpha = softmax(alpha, index, ptr, dim_size, dim)
+            with torch.no_grad():
+                alpha = softmax(alpha, index, ptr, dim_size, dim)
+        else:
+            alpha = softmax(alpha, index, ptr, dim_size, dim)
         return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')
 
     def __repr__(self) -> str:
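
In SoftmaxAggregation, the deleted code detached the logits before calling softmax (and, per the deleted lines, softmax ran a second time on the semi_grad path); the replacement computes the softmax once, inside torch.no_grad(), so the attention weights alpha are constants to autograd and gradients reach the output only through x in x * alpha. A minimal dense sketch of that behavior, assuming a single group and plain torch.softmax rather than PyG's segment-wise softmax:

import torch

def semi_grad_softmax_agg(x: torch.Tensor, t: float = 1.0) -> torch.Tensor:
    with torch.no_grad():              # alpha is a constant to autograd
        alpha = torch.softmax(x * t, dim=0)
    return (x * alpha).sum(dim=0)      # gradient flows only through x

x = torch.randn(5, 8, requires_grad=True)
out = semi_grad_softmax_agg(x, t=0.5)
out.sum().backward()
# With alpha held constant, d(sum(x * alpha))/dx is exactly alpha:
print(torch.allclose(x.grad, torch.softmax((x * 0.5).detach(), dim=0)))
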
torch_geometric/nn/aggr/fused.py (9 additions, 3 deletions)
@@ -251,11 +251,17 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
         # `include_self=True` + manual masking leads to faster runtime:
         out = x.new_full((dim_size, num_feats), fill_value)
 
-        src = x
         if reduce == 'pow_sum':
             reduce = 'sum'
-            src = x.detach() * x.detach() if self.semi_grad else x * x
-        out.scatter_reduce_(0, index, src, reduce, include_self=True)
+            if self.semi_grad:
+                with torch.no_grad():
+                    out.scatter_reduce_(0, index, x * x, reduce,
+                                        include_self=True)
+            else:
+                out.scatter_reduce_(0, index, x * x, reduce,
+                                    include_self=True)
+        else:
+            out.scatter_reduce_(0, index, x, reduce, include_self=True)
 
         if fill_value != 0.0:
             assert mask is not None
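
The fused path builds the per-group sum of squares with the in-place Tensor.scatter_reduce_ and, under semi_grad, runs the call inside torch.no_grad() so that the squared source never enters the autograd graph. A standalone sketch of that pow_sum branch (pow_sum and the toy shapes are illustrative; FusedAggregation pre-expands index to match src and seeds out with new_full plus masking, which the sketch simplifies to new_zeros):

import torch

def pow_sum(x: torch.Tensor, index: torch.Tensor, dim_size: int,
            semi_grad: bool = False) -> torch.Tensor:
    out = x.new_zeros(dim_size, x.size(-1))
    # scatter_reduce_ expects the index shaped like the source tensor:
    idx = index.view(-1, 1).expand_as(x)
    if semi_grad:
        with torch.no_grad():  # x * x leaves no trace in the graph
            out.scatter_reduce_(0, idx, x * x, 'sum', include_self=True)
    else:
        out.scatter_reduce_(0, idx, x * x, 'sum', include_self=True)
    return out

x = torch.randn(6, 4)
index = torch.tensor([0, 0, 1, 2, 2, 2])
print(pow_sum(x, index, dim_size=3).shape)  # torch.Size([3, 4])
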
