Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate split python layer to functor #7030

Merged
merged 7 commits into from
Dec 15, 2021
2 changes: 1 addition & 1 deletion python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def is_deprecated(func_or_class):
from oneflow._C import batch_matmul as bmm
from oneflow._C import broadcast_like
from oneflow._C import chunk
from oneflow._C import split
from oneflow._C import sign
from oneflow._C import sinh
from oneflow._C import tan
Expand Down Expand Up @@ -336,7 +337,6 @@ def atexit_hook(hook):
from oneflow.nn.modules.slice import slice_update_op as slice_update
from oneflow.nn.modules.slice import logical_slice_assign_op as logical_slice_assign
from oneflow.nn.modules.sort import sort_op as sort
from oneflow.nn.modules.split import split_op as split
from oneflow.nn.modules.eye import eye_op as eye
from oneflow.nn.modules.tensor_buffer import gen_tensor_buffer
from oneflow.nn.modules.tensor_buffer import (
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/framework/docstr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@
from .flatten import *
from .chunk import *
from .broadcast_like import *
from .split import *
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Union, List
import numpy as np
import oneflow
from oneflow.framework.docstr.utils import add_docstr

import oneflow as flow
from oneflow.framework.tensor import Tensor, register_tensor_op


@register_tensor_op("split")
def split_op(x, split_size_or_sections: Union[int, List[int]], dim: int = 0):
add_docstr(
oneflow.split,
"""Splits the tensor into chunks.

If `split_size_or_sections` is an integer type, then x will be split into equally sized chunks (if possible).
Expand Down Expand Up @@ -50,11 +46,5 @@ def split_op(x, split_size_or_sections: Union[int, List[int]], dim: int = 0):
[4, 5],
[6, 7],
[8, 9]], dtype=oneflow.int64))
"""
return flow._C.split(x, split_size=split_size_or_sections, dim=dim)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
""",
)
7 changes: 7 additions & 0 deletions python/oneflow/framework/docstr/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,13 @@
""",
)

add_docstr(
oneflow.Tensor.split,
"""
See :func:`oneflow.split`
""",
)

add_docstr(
oneflow.Tensor.cast,
"""
Expand Down
5 changes: 5 additions & 0 deletions python/oneflow/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ def _chunk(self, chunks=None, dim=None):
return flow._C.chunk(self, chunks, dim)


def _split(self, split_size_or_sections=None, dim=None):
return flow._C.split(self, split_size_or_sections, dim)


def _all(self, dim=None, keepdim=False):
return flow.all(self, dim, keepdim)

Expand Down Expand Up @@ -879,6 +883,7 @@ def RegisterMethods():
Tensor.roll = _roll
Tensor.bmm = _bmm
Tensor.chunk = _chunk
Tensor.split = _split
Tensor.squeeze = _squeeze
Tensor.unfold = _unfold
Tensor.narrow = _narrow
Expand Down
6 changes: 3 additions & 3 deletions python/oneflow/test/modules/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_flow_split_with_random_data(test_case):
rand_dim = random(0, 3).to(int)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ndim=4才对,这里越界了。下面的单测都有问题,得改改。

res = torch.split(x, split_size_or_sections=2, dim=rand_dim)
res = torch.split(x, 2, dim=rand_dim)
return torch.cat(res, rand_dim)

@autotest(check_graph=False)
Expand All @@ -41,7 +41,7 @@ def test_flow_split_sizes_with_random_data(test_case):
k2 = random(2, 6)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device)
res = torch.split(x, split_size_or_sections=[1, 2, 3, 1], dim=1)
res = torch.split(x, [1, 2, 3, 1], dim=1)
return torch.cat(res, dim=1)

@autotest(check_graph=False)
Expand All @@ -51,7 +51,7 @@ def test_flow_split_sizes_neg_dim_with_random_data(test_case):
k2 = random(2, 6)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device)
res = torch.split(x, split_size_or_sections=[1, 2, 3, 1], dim=-2)
res = torch.split(x, [1, 2, 3, 1], dim=-2)
return torch.cat(res, dim=1)


Expand Down
23 changes: 23 additions & 0 deletions python/oneflow/test/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,29 @@ def test_tensor_bmm(test_case):
of_out = input1.bmm(input2)
return of_out

@flow.unittest.skip_unless_1n1d()
@autotest(check_graph=False)
def test_tensor_split(test_case):
k0 = random(2, 6)
k1 = random(2, 6)
k2 = random(2, 6)
rand_dim = random(0, 3).to(int)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device)
res = x.split(2, dim=rand_dim)
return torch.cat(res, rand_dim)

@flow.unittest.skip_unless_1n1d()
@autotest(check_graph=False)
def test_tensor_split_sizes(test_case):
k0 = random(2, 6)
k1 = 7
k2 = random(2, 6)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device)
res = x.split([1, 2, 3, 1], dim=-2)
return torch.cat(res, dim=1)


if __name__ == "__main__":
unittest.main()