Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Mar 4, 2024
1 parent 88af85e commit 562738b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
17 changes: 15 additions & 2 deletions test/utils/test_softmax.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import pytest
import torch

import torch_geometric.typing
from torch_geometric.profile import benchmark
from torch_geometric.utils import softmax

CALCULATION_VIA_PTR_AVAILABLE = (torch_geometric.typing.WITH_SOFTMAX
or torch_geometric.typing.WITH_TORCH_SCATTER)


def test_softmax():
src = torch.tensor([1., 1., 1., 1.])
Expand Down Expand Up @@ -53,11 +58,19 @@ def test_softmax_dim():

src = torch.randn(4, 4)
assert torch.allclose(softmax(src, index, dim=-1), src.softmax(dim=-1))
assert torch.allclose(softmax(src, ptr=ptr, dim=-1), src.softmax(-1))
if CALCULATION_VIA_PTR_AVAILABLE:
assert torch.allclose(softmax(src, ptr=ptr, dim=-1), src.softmax(-1))
else:
with pytest.raises(ImportError, match="requires the 'torch-scatter'"):
softmax(src, ptr=ptr, dim=-1)

src = torch.randn(4, 4, 16)
assert torch.allclose(softmax(src, index, dim=1), src.softmax(dim=1))
assert torch.allclose(softmax(src, ptr=ptr, dim=1), src.softmax(dim=1))
if CALCULATION_VIA_PTR_AVAILABLE:
assert torch.allclose(softmax(src, ptr=ptr, dim=1), src.softmax(dim=1))
else:
with pytest.raises(ImportError, match="requires the 'torch-scatter'"):
softmax(src, ptr=ptr, dim=1)


if __name__ == '__main__':
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/utils/_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor:
if not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling():
return _torch_segment(src, ptr, reduce)

if torch_geometric.typing.WITH_PT20 and src.is_cuda and reduce == 'mean':
if (ptr.dim() == 1 and torch_geometric.typing.WITH_PT20 and src.is_cuda
and reduce == 'mean'):
return _torch_segment(src, ptr, reduce)

return torch_scatter.segment_csr(src, ptr, reduce=reduce)
Expand All @@ -34,6 +35,9 @@ def segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor:
def _torch_segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor:
if not torch_geometric.typing.WITH_PT20:
raise ImportError("'segment' requires the 'torch-scatter' package")
if ptr.dim() > 1:
raise ImportError("'segment' in an arbitrary dimension "

Check warning on line 39 in torch_geometric/utils/_segment.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/utils/_segment.py#L39

Added line #L39 was not covered by tests
"requires the 'torch-scatter' package")

if reduce == 'min' or reduce == 'max':
reduce = f'a{reduce}' # `amin` or `amax`
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/utils/_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def softmax(
and not is_compiling()): # pragma: no cover
return pyg_lib.ops.softmax_csr(src, ptr, dim)

if ptr is not None:
if (ptr is not None and
(ptr.dim() == 1 or (ptr.dim() > 1 and index is None) or
(torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()))):

dim = dim + src.dim() if dim < 0 else dim
size = ([1] * dim) + [-1]
count = ptr[1:] - ptr[:-1]
Expand Down

0 comments on commit 562738b

Please sign in to comment.