diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 6c0cafbb489..fcd19b92ff7 100755 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -74,6 +74,7 @@ def is_deprecated(func_or_class): from oneflow._C import batch_matmul as bmm from oneflow._C import broadcast_like from oneflow._C import chunk +from oneflow._C import split from oneflow._C import sign from oneflow._C import sinh from oneflow._C import tan @@ -336,7 +337,6 @@ def atexit_hook(hook): from oneflow.nn.modules.slice import slice_update_op as slice_update from oneflow.nn.modules.slice import logical_slice_assign_op as logical_slice_assign from oneflow.nn.modules.sort import sort_op as sort -from oneflow.nn.modules.split import split_op as split from oneflow.nn.modules.eye import eye_op as eye from oneflow.nn.modules.tensor_buffer import gen_tensor_buffer from oneflow.nn.modules.tensor_buffer import ( diff --git a/python/oneflow/framework/docstr/__init__.py b/python/oneflow/framework/docstr/__init__.py index 2126bd54dfb..89bcba37b81 100644 --- a/python/oneflow/framework/docstr/__init__.py +++ b/python/oneflow/framework/docstr/__init__.py @@ -40,3 +40,4 @@ from .flatten import * from .chunk import * from .broadcast_like import * +from .split import * diff --git a/python/oneflow/nn/modules/split.py b/python/oneflow/framework/docstr/split.py similarity index 80% rename from python/oneflow/nn/modules/split.py rename to python/oneflow/framework/docstr/split.py index cf65fc1012e..7b69312aa08 100644 --- a/python/oneflow/nn/modules/split.py +++ b/python/oneflow/framework/docstr/split.py @@ -13,15 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import Union, List -import numpy as np +import oneflow +from oneflow.framework.docstr.utils import add_docstr -import oneflow as flow -from oneflow.framework.tensor import Tensor, register_tensor_op - - -@register_tensor_op("split") -def split_op(x, split_size_or_sections: Union[int, List[int]], dim: int = 0): +add_docstr( + oneflow.split, """Splits the tensor into chunks. If `split_size_or_sections` is an integer type, then x will be split into equally sized chunks (if possible). @@ -50,11 +46,5 @@ def split_op(x, split_size_or_sections: Union[int, List[int]], dim: int = 0): [4, 5], [6, 7], [8, 9]], dtype=oneflow.int64)) - """ - return flow._C.split(x, split_size=split_size_or_sections, dim=dim) - - -if __name__ == "__main__": - import doctest - - doctest.testmod(raise_on_error=True) + """, +) diff --git a/python/oneflow/framework/docstr/tensor.py b/python/oneflow/framework/docstr/tensor.py index 15eb67ac674..f3a6cbd3b2b 100644 --- a/python/oneflow/framework/docstr/tensor.py +++ b/python/oneflow/framework/docstr/tensor.py @@ -373,6 +373,13 @@ """, ) +add_docstr( + oneflow.Tensor.split, + """ + See :func:`oneflow.split` + """, +) + add_docstr( oneflow.Tensor.cast, """ diff --git a/python/oneflow/framework/tensor.py b/python/oneflow/framework/tensor.py index fad70acdccb..3d73d6d60dc 100644 --- a/python/oneflow/framework/tensor.py +++ b/python/oneflow/framework/tensor.py @@ -588,6 +588,10 @@ def _chunk(self, chunks=None, dim=None): return flow._C.chunk(self, chunks, dim) +def _split(self, split_size_or_sections=None, dim=None): + return flow._C.split(self, split_size_or_sections, dim) + + def _all(self, dim=None, keepdim=False): return flow.all(self, dim, keepdim) @@ -879,6 +883,7 @@ def RegisterMethods(): Tensor.roll = _roll Tensor.bmm = _bmm Tensor.chunk = _chunk + Tensor.split = _split Tensor.squeeze = _squeeze Tensor.unfold = _unfold Tensor.narrow = _narrow diff --git a/python/oneflow/test/modules/test_split.py b/python/oneflow/test/modules/test_split.py index 82460db2c46..7b7f04215ce 100644 --- a/python/oneflow/test/modules/test_split.py +++ b/python/oneflow/test/modules/test_split.py @@ -30,8 +30,8 @@ def test_flow_split_with_random_data(test_case): k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() - x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) - res = torch.split(x, split_size_or_sections=2, dim=rand_dim) + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = torch.split(x, 2, dim=rand_dim) return torch.cat(res, rand_dim) @autotest(check_graph=False) @@ -40,8 +40,8 @@ def test_flow_split_sizes_with_random_data(test_case): k1 = 7 k2 = random(2, 6) device = random_device() - x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) - res = torch.split(x, split_size_or_sections=[1, 2, 3, 1], dim=1) + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = torch.split(x, [1, 2, 3, 1], dim=1) return torch.cat(res, dim=1) @autotest(check_graph=False) @@ -50,8 +50,8 @@ def test_flow_split_sizes_neg_dim_with_random_data(test_case): k1 = 7 k2 = random(2, 6) device = random_device() - x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) - res = torch.split(x, split_size_or_sections=[1, 2, 3, 1], dim=-2) + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = torch.split(x, [1, 2, 3, 1], dim=-2) return torch.cat(res, dim=1) diff --git a/python/oneflow/test/tensor/test_tensor.py b/python/oneflow/test/tensor/test_tensor.py index 2a8934fbe16..853140b24b4 100644 --- a/python/oneflow/test/tensor/test_tensor.py +++ b/python/oneflow/test/tensor/test_tensor.py @@ -1623,6 +1623,29 @@ def test_tensor_bmm(test_case): of_out = input1.bmm(input2) return of_out + @flow.unittest.skip_unless_1n1d() + @autotest(check_graph=False) + def test_tensor_split(test_case): + k0 = random(2, 6) + k1 = random(2, 6) + k2 = random(2, 6) + rand_dim = random(0, 3).to(int) + device = random_device() + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = x.split(2, dim=rand_dim) + return torch.cat(res, rand_dim) + + @flow.unittest.skip_unless_1n1d() + @autotest(check_graph=False) + def test_tensor_split_sizes(test_case): + k0 = random(2, 6) + k1 = 7 + k2 = random(2, 6) + device = random_device() + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = x.split([1, 2, 3, 1], dim=-2) + return torch.cat(res, dim=1) + if __name__ == "__main__": unittest.main()