Skip to content

Commit

Permalink
Migrate split python layer to functor (#7030)
Browse files Browse the repository at this point in the history
* Migrate split python layer to functor

* modify dim

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
lixiang007666 and oneflow-ci-bot authored Dec 15, 2021
1 parent 52b6560 commit 31a64d3
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 23 deletions.
2 changes: 1 addition & 1 deletion python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def is_deprecated(func_or_class):
from oneflow._C import batch_matmul as bmm
from oneflow._C import broadcast_like
from oneflow._C import chunk
from oneflow._C import split
from oneflow._C import sign
from oneflow._C import sinh
from oneflow._C import tan
Expand Down Expand Up @@ -336,7 +337,6 @@ def atexit_hook(hook):
from oneflow.nn.modules.slice import slice_update_op as slice_update
from oneflow.nn.modules.slice import logical_slice_assign_op as logical_slice_assign
from oneflow.nn.modules.sort import sort_op as sort
from oneflow.nn.modules.split import split_op as split
from oneflow.nn.modules.eye import eye_op as eye
from oneflow.nn.modules.tensor_buffer import gen_tensor_buffer
from oneflow.nn.modules.tensor_buffer import (
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/framework/docstr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@
from .flatten import *
from .chunk import *
from .broadcast_like import *
from .split import *
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Union, List
import numpy as np
import oneflow
from oneflow.framework.docstr.utils import add_docstr

import oneflow as flow
from oneflow.framework.tensor import Tensor, register_tensor_op


@register_tensor_op("split")
def split_op(x, split_size_or_sections: Union[int, List[int]], dim: int = 0):
add_docstr(
oneflow.split,
"""Splits the tensor into chunks.
If `split_size_or_sections` is an integer type, then x will be split into equally sized chunks (if possible).
Expand Down Expand Up @@ -50,11 +46,5 @@ def split_op(x, split_size_or_sections: Union[int, List[int]], dim: int = 0):
[4, 5],
[6, 7],
[8, 9]], dtype=oneflow.int64))
"""
return flow._C.split(x, split_size=split_size_or_sections, dim=dim)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
""",
)
7 changes: 7 additions & 0 deletions python/oneflow/framework/docstr/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,13 @@
""",
)

add_docstr(
oneflow.Tensor.split,
"""
See :func:`oneflow.split`
""",
)

add_docstr(
oneflow.Tensor.cast,
"""
Expand Down
5 changes: 5 additions & 0 deletions python/oneflow/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ def _chunk(self, chunks=None, dim=None):
return flow._C.chunk(self, chunks, dim)


def _split(self, split_size_or_sections=None, dim=None):
return flow._C.split(self, split_size_or_sections, dim)


def _all(self, dim=None, keepdim=False):
return flow.all(self, dim, keepdim)

Expand Down Expand Up @@ -879,6 +883,7 @@ def RegisterMethods():
Tensor.roll = _roll
Tensor.bmm = _bmm
Tensor.chunk = _chunk
Tensor.split = _split
Tensor.squeeze = _squeeze
Tensor.unfold = _unfold
Tensor.narrow = _narrow
Expand Down
12 changes: 6 additions & 6 deletions python/oneflow/test/modules/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def test_flow_split_with_random_data(test_case):
k2 = random(2, 6)
rand_dim = random(0, 3).to(int)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device)
res = torch.split(x, split_size_or_sections=2, dim=rand_dim)
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)
res = torch.split(x, 2, dim=rand_dim)
return torch.cat(res, rand_dim)

@autotest(check_graph=False)
Expand All @@ -40,8 +40,8 @@ def test_flow_split_sizes_with_random_data(test_case):
k1 = 7
k2 = random(2, 6)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device)
res = torch.split(x, split_size_or_sections=[1, 2, 3, 1], dim=1)
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)
res = torch.split(x, [1, 2, 3, 1], dim=1)
return torch.cat(res, dim=1)

@autotest(check_graph=False)
Expand All @@ -50,8 +50,8 @@ def test_flow_split_sizes_neg_dim_with_random_data(test_case):
k1 = 7
k2 = random(2, 6)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device)
res = torch.split(x, split_size_or_sections=[1, 2, 3, 1], dim=-2)
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)
res = torch.split(x, [1, 2, 3, 1], dim=-2)
return torch.cat(res, dim=1)


Expand Down
23 changes: 23 additions & 0 deletions python/oneflow/test/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,29 @@ def test_tensor_bmm(test_case):
of_out = input1.bmm(input2)
return of_out

@flow.unittest.skip_unless_1n1d()
@autotest(check_graph=False)
def test_tensor_split(test_case):
k0 = random(2, 6)
k1 = random(2, 6)
k2 = random(2, 6)
rand_dim = random(0, 3).to(int)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)
res = x.split(2, dim=rand_dim)
return torch.cat(res, rand_dim)

@flow.unittest.skip_unless_1n1d()
@autotest(check_graph=False)
def test_tensor_split_sizes(test_case):
k0 = random(2, 6)
k1 = 7
k2 = random(2, 6)
device = random_device()
x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device)
res = x.split([1, 2, 3, 1], dim=-2)
return torch.cat(res, dim=1)


if __name__ == "__main__":
unittest.main()

0 comments on commit 31a64d3

Please sign in to comment.