diff --git a/docs/test/test_documentation_examples.py b/docs/test/test_documentation_examples.py
index 3cf7ac3b30..3931bdfd15 100644
--- a/docs/test/test_documentation_examples.py
+++ b/docs/test/test_documentation_examples.py
@@ -6,6 +6,11 @@
 import sys
 import unittest
 
+import torch
+
+from onnxscript import evaluator
+from onnxscript.function_libs.torch_lib import ops
+
 
 class TestDocumentationExample(unittest.TestCase):
     def do_test_folder(self, folder):
@@ -22,8 +27,7 @@ def do_test_folder(self, folder):
                 with subprocess.Popen(
                     cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE
                 ) as p:
-                    res = p.communicate()
-                    _, err = res
+                    _, err = p.communicate()
                     st = err.decode("ascii", errors="ignore")
                     if len(st) > 0 and "Traceback" in st:
                         raise RuntimeError(  # pylint: disable=W0707
@@ -54,5 +58,18 @@ def test(*relpath):
         test("..", "..", "docs", "tutorial", "rewriter", "examples")
 
 
+def test_unbind_matches_torch():
+    x_torch = torch.randn(3, 4)
+    y_torch = torch.unbind(x_torch, dim=1)
+
+    x_np = x_torch.detach().cpu().numpy()
+    eager = evaluator.default()
+    y_onnx = eager.eval_function(ops.core.aten_unbind, (x_np,), {"dim": 1})
+
+    assert len(y_torch) == len(y_onnx)
+    for a, b in zip(y_torch, y_onnx):
+        assert a.shape == tuple(b.shape), f"Shape mismatch: {a.shape} vs {b.shape}"
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index ab992e0580..1b68b80b38 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -8618,10 +8618,13 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
 
 
 @torch_op("aten::unbind.int")
 def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
-    """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
-    split_sizes = op.Constant(value_int=1)
-    return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False)
+    """unbind(Tensor self, int dim=0) -> Tensor[]
+    Splits a tensor into multiple tensors along the given dimension without keeping the dimension.
+    Matches the behavior of torch.unbind.
+    """
+    split_size = op.Constant(value_int=1)
+    return op.SplitToSequence(self, split_size, axis=dim, keepdims=False)
 
 
 @torch_op("aten::unflatten.int", trace_only=True)