Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate chunk python layer to functor #6983

Merged
merged 16 commits into from
Dec 11, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,12 @@
]
bind_python: True

- name: "chunk"
signature: [
"TensorTuple (Tensor x, Int64 chunks, Int64 dim=0) => Chunk",
]
bind_python: True

- name: "split_like"
signature: "TensorTuple (Tensor x, TensorTuple like, Int64 axis) => SplitLike"
bind_python: True
Expand Down
28 changes: 28 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,33 @@ class SplitFunctor {
}
};

class ChunkFunctor {
 public:
  ChunkFunctor() {}
  // Splits `x` into `chunks` pieces along dimension `dim`.
  //
  // Matches the documented Python behavior (see the `oneflow.chunk` docstring
  // example producing sizes [2, 2, 2, 3] for dim_size=9, chunks=4): every
  // chunk has floor(dim_size / chunks) elements and the LAST chunk absorbs
  // the remainder, so it is the larger one when the split is uneven.
  //
  // Fixes in this revision:
  //  * the original collapsed to a single whole-tensor chunk whenever
  //    dim_size was not divisible by chunks (its recomputed split_size
  //    algebraically equaled dim_size);
  //  * a negative `dim` was used to index the shape before normalization;
  //  * `num_splits` divided by split_size, which is 0 when dim_size < chunks.
  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& chunks,
                                const int64_t& dim) const {
    const int64_t ndim = x->ndim();
    // Normalize a negative dim BEFORE touching the shape.
    int64_t axis = dim;
    if (axis < 0) { axis += ndim; }
    CHECK_OR_RETURN(axis >= 0 && axis < ndim)
        << "The dim " << dim << " is out of bound " << ndim - 1;
    CHECK_GT_OR_RETURN(chunks, 0) << "chunk expects chunks to be greater than 0, got " << chunks;
    const int64_t dim_size = x->shape()->At(axis);
    // Floor division: the size of every chunk except possibly the last.
    const int64_t split_size = dim_size / chunks;
    // Always emit exactly `chunks` pieces; the final piece gets whatever
    // remains (>= split_size). When divisible, all pieces are equal.
    const int64_t last_split_size = dim_size - split_size * (chunks - 1);
    TensorTuple splits(chunks);
    for (int64_t i = 0; i < chunks; ++i) {
      const int64_t length = (i < chunks - 1) ? split_size : last_split_size;
      splits[i] = JUST(Narrow(x, axis, i * split_size, length));
    }
    return splits;
  }
};

class SplitLikeFunctor {
public:
SplitLikeFunctor() {
Expand Down Expand Up @@ -2326,6 +2353,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::ReduceSumLikeFunctor>("ReduceSumLike");
m.add_functor<impl::BroadcastReduceSumLikeFunctor>("BroadcastReduceSumLike");
m.add_functor<impl::SplitFunctor>("Split");
m.add_functor<impl::ChunkFunctor>("Chunk");
m.add_functor<impl::SplitLikeFunctor>("SplitLike");
m.add_functor<impl::SplitWithSizeFunctor>("SplitWithSize");
m.add_functor<impl::BatchGatherFunctor>("BatchGather");
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def is_deprecated(func_or_class):
from oneflow._C import atanh
from oneflow._C import atanh as arctanh
from oneflow._C import batch_matmul as bmm
from oneflow._C import chunk
from oneflow._C import sign
from oneflow._C import sinh
from oneflow._C import tan
Expand Down Expand Up @@ -284,7 +285,6 @@ def atexit_hook(hook):
from oneflow.nn.modules.argsort import argsort_op as argsort
from oneflow.nn.modules.argwhere import argwhere_op as argwhere
from oneflow.nn.modules.broadcast_like import broadcast_like_op as broadcast_like
from oneflow.nn.modules.chunk import chunk_op as chunk
from oneflow.nn.modules.constant import ones_op as ones
from oneflow.nn.modules.constant import zeros_op as zeros
from oneflow.nn.modules.constant import full_op as full
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/framework/docstr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@
from .dataset import *
from .bmm import *
from .flatten import *
from .chunk import *
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.ops.array_ops import parse_slice_tuple_list
import oneflow
from oneflow.framework.docstr.utils import add_docstr


@register_tensor_op("chunk")
def chunk_op(input, chunks, dim: int = 0):
add_docstr(
oneflow.chunk,
"""Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor. Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.

Args:
Expand Down Expand Up @@ -59,16 +57,5 @@ def chunk_op(input, chunks, dim: int = 0):
>>> of_out_shape
[(5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 3)]

"""
split_size = input.shape[dim] // chunks
if split_size * chunks != input.shape[dim]:
split_size = [split_size] * (chunks - 1) + [
input.shape[dim] - split_size * (chunks - 1)
]
return flow._C.split(input, split_size=split_size, dim=dim)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
""",
)
8 changes: 7 additions & 1 deletion python/oneflow/framework/docstr/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,14 +337,20 @@
""",
)

# Attach documentation to the Tensor.chunk method; the full description
# (args, example) lives on the functional form, oneflow.chunk.
add_docstr(
oneflow.Tensor.chunk,
"""
See :func:`oneflow.chunk`
""",
)

add_docstr(
oneflow.Tensor.cast,
"""
See :func:`oneflow.cast`
""",
)


add_docstr(
oneflow.Tensor.diag,
"""
Expand Down
5 changes: 5 additions & 0 deletions python/oneflow/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,10 @@ def _bmm(self, other):
return flow.bmm(self, other)


def _chunk(self, chunks=None, dim=None):
    """Tensor method form of :func:`oneflow.chunk`.

    Splits ``self`` into ``chunks`` pieces along dimension ``dim``.
    """
    # The C++ binding declares `Int64 dim=0`; forwarding Python's None would
    # fail type conversion, so mirror the functional default here. `chunks`
    # has no meaningful default — a None there is a caller error and is left
    # to the binding layer to report.
    if dim is None:
        dim = 0
    return flow.chunk(self, chunks, dim)
BBuf marked this conversation as resolved.
Show resolved Hide resolved


def _all(self, dim=None, keepdim=False):
return flow.all(self, dim, keepdim)

Expand Down Expand Up @@ -869,6 +873,7 @@ def RegisterMethods():
Tensor.logical_not = _not
Tensor.roll = _roll
Tensor.bmm = _bmm
Tensor.chunk = _chunk
Tensor.squeeze = _squeeze
Tensor.unfold = _unfold
Tensor.narrow = _narrow
Expand Down