Skip to content

Commit

Permalink
feat: support for PyTorch tensors in cc3d.connected_components
Browse files Browse the repository at this point in the history
  • Loading branch information
william-silversmith committed Jun 17, 2024
1 parent 447bbb7 commit bb85ed6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
19 changes: 19 additions & 0 deletions automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,25 @@ def test_color_connectivity_graph_26():
cc_labels = cc3d.color_connectivity_graph(vcg, connectivity=26)
assert np.all(cc_labels == 1)

def test_pytorch_integration_ccl_doesnt_crash():
torch = pytest.importorskip("torch")

labels = torch.from_numpy(np.zeros([100,100,100], dtype=np.uint16))

out = cc3d.connected_components(labels)

assert isinstance(out, torch.Tensor)
assert torch.all(out == labels)













9 changes: 9 additions & 0 deletions cc3d.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,11 @@ def connected_components(
if return_N: (OUT, N)
else: OUT
"""
is_torch = hasattr(data, "cpu")
if is_torch:
# don't need to call .detach() b/c its read-only
data = data.cpu().numpy()
cdef int dims = len(data.shape)
if dims not in (1,2,3):
raise DimensionError("Only 1D, 2D, and 3D arrays supported. Got: " + str(dims))
Expand Down Expand Up @@ -569,6 +574,10 @@ def connected_components(
out_labels = _final_reshape(out_labels, sx, sy, sz, dims, order)
if is_torch:
import torch
out_labels = torch.from_numpy(out_labels)
if return_N:
return (out_labels, N)
return out_labels
Expand Down

0 comments on commit bb85ed6

Please sign in to comment.