Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate chunk python layer to functor #6983

Merged
merged 16 commits into from
Dec 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,12 @@
]
bind_python: True

- name: "chunk"
signature: [
"TensorTuple (Tensor x, Int64 chunks, Int64 dim=0) => Chunk",
]
bind_python: True

- name: "split_like"
signature: "TensorTuple (Tensor x, TensorTuple like, Int64 axis) => SplitLike"
bind_python: True
Expand Down
47 changes: 47 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,52 @@ class SplitFunctor {
}
};

class ChunkFunctor {
public:
ChunkFunctor() {}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const int64_t& chunks,
const int64_t& dim) const {
int64_t axis = dim;
int64_t split_size = x->shape()->At(dim) / chunks;
int64_t dim_size = x->shape()->At(axis);
if (axis < 0) { axis += x->ndim(); }
CHECK_OR_RETURN(axis >= 0 && axis < x->ndim())
<< "Dimension out of range (expected to be in range of [" << -(x->ndim()) << ", "
<< x->ndim() - 1 << "], but got " << dim;
if ((split_size * chunks) != x->shape()->At(dim)) {
std::vector<int64_t> sections;
for (int i = 0; i < chunks - 1; ++i) { sections.emplace_back(split_size); }
sections.emplace_back(x->shape()->At(dim) - split_size * (chunks - 1));
int64_t num_splits = sections.size();
TensorTuple splits(num_splits);
int64_t start_idx = 0;
for (int i = 0; i < num_splits; ++i) {
int64_t length = sections[i];
CHECK_GE_OR_RETURN(length, 0) << "split_with_sizes expects split_sizes have only "
"non-negative entries, but split_sizes["
<< i << "] = " << length;
splits[i] = JUST(Narrow(x, axis, start_idx, length));
start_idx += length;
}
CHECK_EQ_OR_RETURN(start_idx, dim_size)
<< "split_with_sizes expects split_sizes to sum exactly to " << dim_size
<< " (input tensor's size at dimension " << axis << "), "
<< "but got sum(split_sizes)=" << start_idx;
return splits;
}
CHECK_GE_OR_RETURN(split_size, 0)
<< "split expects split_size be non-negative, but got split_size=" << split_size;
int64_t num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
TensorTuple splits(num_splits);
int64_t last_split_size = split_size - (split_size * num_splits - dim_size);
for (int i = 0; i < num_splits; ++i) {
int64_t length = i < num_splits - 1 ? split_size : last_split_size;
splits[i] = JUST(Narrow(x, axis, i * split_size, length));
}
return splits;
}
};

class SplitLikeFunctor {
public:
SplitLikeFunctor() {
Expand Down Expand Up @@ -2326,6 +2372,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::ReduceSumLikeFunctor>("ReduceSumLike");
m.add_functor<impl::BroadcastReduceSumLikeFunctor>("BroadcastReduceSumLike");
m.add_functor<impl::SplitFunctor>("Split");
m.add_functor<impl::ChunkFunctor>("Chunk");
m.add_functor<impl::SplitLikeFunctor>("SplitLike");
m.add_functor<impl::SplitWithSizeFunctor>("SplitWithSize");
m.add_functor<impl::BatchGatherFunctor>("BatchGather");
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def is_deprecated(func_or_class):
from oneflow._C import atanh
from oneflow._C import atanh as arctanh
from oneflow._C import batch_matmul as bmm
from oneflow._C import chunk
from oneflow._C import sign
from oneflow._C import sinh
from oneflow._C import tan
Expand Down Expand Up @@ -284,7 +285,6 @@ def atexit_hook(hook):
from oneflow.nn.modules.argsort import argsort_op as argsort
from oneflow.nn.modules.argwhere import argwhere_op as argwhere
from oneflow.nn.modules.broadcast_like import broadcast_like_op as broadcast_like
from oneflow.nn.modules.chunk import chunk_op as chunk
from oneflow.nn.modules.constant import ones_op as ones
from oneflow.nn.modules.constant import zeros_op as zeros
from oneflow.nn.modules.constant import full_op as full
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/framework/docstr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@
from .dataset import *
from .bmm import *
from .flatten import *
from .chunk import *
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.ops.array_ops import parse_slice_tuple_list
import oneflow
from oneflow.framework.docstr.utils import add_docstr


@register_tensor_op("chunk")
def chunk_op(input, chunks, dim: int = 0):
add_docstr(
oneflow.chunk,
"""Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor. Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.

Args:
Expand Down Expand Up @@ -59,16 +57,5 @@ def chunk_op(input, chunks, dim: int = 0):
>>> of_out_shape
[(5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 3)]

"""
split_size = input.shape[dim] // chunks
if split_size * chunks != input.shape[dim]:
split_size = [split_size] * (chunks - 1) + [
input.shape[dim] - split_size * (chunks - 1)
]
return flow._C.split(input, split_size=split_size, dim=dim)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
""",
)
8 changes: 7 additions & 1 deletion python/oneflow/framework/docstr/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,14 +337,20 @@
""",
)

add_docstr(
oneflow.Tensor.chunk,
"""
See :func:`oneflow.chunk`
""",
)

add_docstr(
oneflow.Tensor.cast,
"""
See :func:`oneflow.cast`
""",
)


add_docstr(
oneflow.Tensor.diag,
"""
Expand Down
5 changes: 5 additions & 0 deletions python/oneflow/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,10 @@ def _bmm(self, other):
return flow.bmm(self, other)


def _chunk(self, chunks=None, dim=None):
return flow._C.chunk(self, chunks, dim)


def _all(self, dim=None, keepdim=False):
return flow.all(self, dim, keepdim)

Expand Down Expand Up @@ -869,6 +873,7 @@ def RegisterMethods():
Tensor.logical_not = _not
Tensor.roll = _roll
Tensor.bmm = _bmm
Tensor.chunk = _chunk
Tensor.squeeze = _squeeze
Tensor.unfold = _unfold
Tensor.narrow = _narrow
Expand Down