diff --git a/torch_geometric/nn/aggr/basic.py b/torch_geometric/nn/aggr/basic.py
index d9370e2da27e..b30ba5dee8a7 100644
--- a/torch_geometric/nn/aggr/basic.py
+++ b/torch_geometric/nn/aggr/basic.py
@@ -101,9 +101,12 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
         mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')
-        x2 = x.detach() * x.detach() if self.semi_grad else x * x
-        mean_2 = self.reduce(x2, index, ptr, dim_size, dim, 'mean')
-        return mean_2 - mean * mean
+        if self.semi_grad:
+            with torch.no_grad():
+                mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')
+        else:
+            mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')
+        return mean2 - mean * mean
 
 
 class StdAggregation(Aggregation):
@@ -200,9 +203,10 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
             alpha = x * t
 
         if not self.learn and self.semi_grad:
-            alpha = alpha.detach()
-            alpha = softmax(alpha, index, ptr, dim_size, dim)
-        alpha = softmax(alpha, index, ptr, dim_size, dim)
+            with torch.no_grad():
+                alpha = softmax(alpha, index, ptr, dim_size, dim)
+        else:
+            alpha = softmax(alpha, index, ptr, dim_size, dim)
         return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')
 
     def __repr__(self) -> str:
diff --git a/torch_geometric/nn/aggr/fused.py b/torch_geometric/nn/aggr/fused.py
index f7b629daab18..abf74691b05a 100644
--- a/torch_geometric/nn/aggr/fused.py
+++ b/torch_geometric/nn/aggr/fused.py
@@ -251,11 +251,17 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 # `include_self=True` + manual masking leads to faster runtime:
                 out = x.new_full((dim_size, num_feats), fill_value)
 
-                src = x
                 if reduce == 'pow_sum':
                     reduce = 'sum'
-                    src = x.detach() * x.detach() if self.semi_grad else x * x
-                out.scatter_reduce_(0, index, src, reduce, include_self=True)
+                    if self.semi_grad:
+                        with torch.no_grad():
+                            out.scatter_reduce_(0, index, x * x, reduce,
+                                                include_self=True)
+                    else:
+                        out.scatter_reduce_(0, index, x * x, reduce,
+                                            include_self=True)
+                else:
+                    out.scatter_reduce_(0, index, x, reduce, include_self=True)
 
                 if fill_value != 0.0:
                     assert mask is not None