Feat lazy tensor indexing (#9334)
* feat(boxing): collective_boxing slice_boxing support 0size tensor

* test(Indexing): add lazy tensor basic indexing

* add MaskTensor judgement

* format code

* feat(TensorIndexing): support lazy advanced getitem indexing

* feat(Indexing): support lazy indexing for lazy_tensor and free_tensor

* fix(Indexing): fix indexing test bug

* test(Indexing): test all advanced indexing

* test(GlobalIndexing): fix eager global indexing bug

* test(Indexing): support combined indexing

* add last test cases

* fix merge bug

* fix lazy mode guard

* test(Indexing): refine set scalar value test

* test(Indexing): enable all bool tensor index setitem

* decrease test time

* refine 0size shape judgement

* add comment
wyg1997 authored Nov 23, 2022
1 parent e4a80d8 commit 0434698
Showing 9 changed files with 1,010 additions and 33 deletions.
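In user-facing terms, this PR lets getitem/setitem trace inside `nn.Graph` (lazy mode). A minimal sketch of what now works, assuming the standard `oneflow.nn.Graph` tracing API (`GetItemGraph` is a hypothetical name; output shapes in comments are illustrative):

```python
import oneflow as flow

class GetItemGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()

    def build(self, x):
        # basic, advanced, and combined indexing now trace in lazy mode
        return x[1:, [0, 2]]

y = GetItemGraph()(flow.randn(3, 4))
print(y.shape)  # flow.Size([2, 2])
```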
65 changes: 35 additions & 30 deletions oneflow/api/python/framework/tensor_functions.cpp
@@ -784,8 +784,6 @@ static PyObject* PyTensorObject_to_local(PyObject* self, PyObject* unused, PyObj

int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {
HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self);
std::shared_ptr<Tensor> value_tensor;
CHECK_OR_THROW(functional::PyTensorIndexCheck(item))
<< Error::TypeError() << "tensor_setitem(): argument 'index' must be index, not "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(item)));
@@ -794,6 +792,7 @@ int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(value)));
const auto& index_item = functional::PyUnpackTensorIndex(item);

auto tensor = PyTensor_Unpack(self);
// NOTE: use masked_fill_(local,global) to avoid D2H in TensorSetItem if index is bool tensor
if (functional::PyScalarCheck(value) && index_item.size() == 1 && index_item[0].IsTensor()) {
const auto& index_tensor = index_item[0].tensor();
@@ -805,35 +804,41 @@
}
}

if (tensor->is_global()) {
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
auto ndsbp = ASSERT(tensor->nd_sbp());
std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),
ASSERT(MakeBroadcastSbpParallel()));
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::GlobalConstant(Shape({}), value_scalar, tensor->dtype(), placement, sbp));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_global())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a global tensor when self is global";
value_tensor =
ASSERT_PTR(functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false));
}
} else {
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::Constant(Shape({}), value_scalar, tensor->dtype(), ASSERT(tensor->device())));
std::shared_ptr<Tensor> value_tensor;
{
if (tensor->is_global()) {
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
auto ndsbp = ASSERT(tensor->nd_sbp());
std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),
ASSERT(MakeBroadcastSbpParallel()));
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::GlobalConstant(Shape({}), value_scalar, tensor->dtype(), placement, sbp));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_global())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a global tensor when self is global";
value_tensor = ASSERT_PTR(
functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false));
}
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_local())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a local tensor when self is local";
Optional<Symbol<Device>> device = ASSERT(tensor->device());
value_tensor = ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false));
if (functional::PyScalarCheck(value)) {
// NOTE: initialize value_tensor in eager mode
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled=*/false);
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(functional::Constant(Shape({}), value_scalar, tensor->dtype(),
ASSERT(tensor->device())));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_local())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a local tensor when self is local";
Optional<Symbol<Device>> device = ASSERT(tensor->device());
value_tensor =
ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false));
}
}
}
ASSERT(functional::TensorSetItem(tensor, index_item, value_tensor));
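The hunk above moves `value_tensor` construction after the bool-mask fast path, and the new `LazyMode::Guard` builds scalar constants eagerly even while tracing. Per the NOTE in the code, the fast path behaves like `masked_fill_`; a hedged eager-mode sketch:

```python
import oneflow as flow

x = flow.randn(4, 4)
mask = x > 0  # a single bool index tensor

# scalar value + one bool-tensor index: setitem dispatches to the
# masked_fill_ path, so the mask never round-trips device-to-host
x[mask] = 0.0  # behaves like x.masked_fill_(mask, 0.0)
```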
2 changes: 2 additions & 0 deletions oneflow/api/python/functional/indexing.cpp
@@ -156,6 +156,8 @@ Shape InferArraySizes(PyObject* object) {
}

Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) {
// NOTE: converting data to an indexing tensor must run in eager mode
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
const DataType dtype = InferScalarType(object);
const auto& device = JUST(Device::New("cpu"));

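With the guard in place, Python-literal indices are materialized as eager CPU tensors even while a graph is being traced. A sketch under the usual `nn.Graph` subclassing pattern (`ListIndexGraph` is a hypothetical name):

```python
import oneflow as flow

class ListIndexGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()

    def build(self, x):
        # the Python list is converted to an index tensor eagerly
        # (lazy mode disabled), then consumed by the traced graph
        return x[[0, 2]]

out = ListIndexGraph()(flow.randn(4, 5))
print(out.shape)  # flow.Size([2, 5])
```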
9 changes: 8 additions & 1 deletion oneflow/core/functional/tensor_index.cpp
@@ -42,7 +42,7 @@ int64_t CountSpecifiedDims(const TensorIndex& index) {
specified_ndims++;
} else if (index_item.IsTensor()) {
const auto& tensor = index_item.tensor();
if (tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8()) {
if (IsMaskTensor(tensor)) {
specified_ndims += tensor->ndim();
} else {
specified_ndims++;
@@ -247,6 +247,11 @@ Maybe<Tensor> PermuteBackForGlobalTensor(const std::shared_ptr<Tensor>& result,

} // namespace

bool IsMaskTensor(const std::shared_ptr<Tensor>& tensor) {
return tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8()
|| tensor->dtype() == DType::Bool();
}

Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
std::vector<detail::Slice>* slice_indices,
TensorTuple* tensor_indices, std::vector<int64_t>* expand_dims,
@@ -558,6 +563,8 @@ Maybe<void> UnifyInputAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,
const auto tensor_index = tensor_indices[i];
if (tensor_index == nullptr) { continue; }
if (tensor_index->is_local()) {
// NOTE: LocalToGlobal should be called in eager mode
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
tensor_indices[i] = JUST(ToGlobal(tensor_index, placement,
std::vector<Symbol<SbpParallel>>(n, broadcast_sbp),
grad_sbp_tuple, /*check_meta=*/false, /*copy=*/false));
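`IsMaskTensor` widens the mask test from int8/uint8 to include bool, which matters for dim counting: a mask consumes as many dims as its own ndim, while an integer index tensor consumes exactly one. A small eager illustration (shapes in comments are illustrative):

```python
import oneflow as flow

x = flow.arange(6).reshape(2, 3)

# a bool (or int8/uint8) mask spans all of its own dims
mask = flow.ones(2, 3, dtype=flow.bool)
print(x[mask].shape)  # flow.Size([6])

# an integer index tensor consumes a single dim
idx = flow.tensor([0, 1])
print(x[idx].shape)   # flow.Size([2, 3])
```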
2 changes: 2 additions & 0 deletions oneflow/core/functional/tensor_index.h
@@ -109,6 +109,8 @@ class TensorIndex : public std::vector<detail::IndexItem> {
using std::vector<detail::IndexItem>::vector;
};

bool IsMaskTensor(const std::shared_ptr<Tensor>& tensor);

Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
std::vector<detail::Slice>* slice_indices,
TensorTuple* tensor_indices, std::vector<int64_t>* expand_dims,
2 changes: 2 additions & 0 deletions oneflow/core/kernel/collective_boxing_unpack_kernel.cpp
@@ -36,6 +36,8 @@ void CollectiveBoxingUnpackKernel::ForwardDataContent(KernelContext* ctx) const
const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf();
const int64_t num_ranks = unpack_conf.num_ranks();
const Shape logical_shape(unpack_conf.logical_shape());
// skip 0size tensor boxing
if (logical_shape.elem_cnt() == 0) { return; }
const bool need_transpose = !((unpack_conf.src_sbp_parallel().has_split_parallel()
&& unpack_conf.src_sbp_parallel().split_parallel().axis() == 0)
|| unpack_conf.src_sbp_parallel().has_broadcast_parallel()
6 changes: 6 additions & 0 deletions oneflow/core/kernel/slice_boxing_kernel.cpp
@@ -63,6 +63,10 @@ class SliceBoxingAddKernel final : public SliceBoxingKernel {

void SliceBoxingKernel::VirtualKernelInit(KernelContext* ctx) {
const SliceBoxingConf& conf = GetCustomizedBoxingConf();
if (/*is_0size_tensor=*/std::any_of(conf.out_shape().dim().begin(), conf.out_shape().dim().end(),
[](int64_t dim) { return dim == 0; })) {
return;
}
const TensorSliceView out_slice(conf.out_slice());
for (const TensorSliceViewProto& in_slice_proto : conf.in_slice()) {
const TensorSliceView in_slice(in_slice_proto);
@@ -82,6 +86,7 @@ const SliceBoxingConf& SliceBoxingCopyKernel::GetCustomizedBoxingConf() const {

void SliceBoxingCopyKernel::ForwardDataContent(KernelContext* ctx) const {
Blob* out = ctx->BnInOp2Blob("out");
if (out->shape_view().elem_cnt() == 0) { return; }
FOR_RANGE(int64_t, i, 0, this->op_attribute().input_bns().size()) {
const Blob* in_i = ctx->BnInOp2Blob(GenRepeatedBn("in", i));
this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), out, in_i);
@@ -94,6 +99,7 @@ const SliceBoxingConf& SliceBoxingAddKernel::GetCustomizedBoxingConf() const {

void SliceBoxingAddKernel::ForwardDataContent(KernelContext* ctx) const {
Blob* out = ctx->BnInOp2Blob("out");
if (out->shape_view().elem_cnt() == 0) { return; }
std::unique_ptr<ep::primitive::Add> primitive =
ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->stream()->device_type(),
out->data_type());
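Both boxing kernels now no-op on zero-element outputs, mirroring the `collective_boxing_unpack` change above. Zero-size tensors arise naturally from slicing; a single-rank sketch (multi-rank jobs are where these boxing paths would actually trigger):

```python
import oneflow as flow

placement = flow.placement("cpu", ranks=[0])
x = flow.randn(4, 4).to_global(placement, flow.sbp.broadcast)

# an empty slice yields a 0-size global tensor; boxing it is now a no-op
empty = x[2:2]
print(empty.shape)  # flow.Size([0, 4])
```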
6 changes: 5 additions & 1 deletion oneflow/user/ops/slice_op.cpp
@@ -104,7 +104,11 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) {
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> SliceUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
return InferLogicalTensorDesc(ctx);
const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
auto* y_desc = ctx->MutOutputTensorDesc("y", 0);
y_desc->set_shape(ref_desc.shape());
y_desc->set_is_dynamic(ref_desc.is_dynamic());
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> SliceUpdateOp::InferDataType(user_op::InferContext* ctx) {
const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
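`SliceUpdateOp`'s physical tensor desc is now derived directly from `ref` instead of re-running the logical inference, so the output always keeps `ref`'s (physical) shape. On the Python side, basic-slice assignment lowers to this op; a sketch:

```python
import oneflow as flow

ref = flow.zeros(2, 8)
# slice assignment lowers to slice_update; the output keeps ref's shape
ref[:, 0:4] = flow.ones(2, 4)
print(ref.shape)  # flow.Size([2, 8])
```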
2 changes: 1 addition & 1 deletion python/oneflow/test/tensor/test_global_tensor_indexing.py
@@ -259,7 +259,7 @@ def _test_advanced_indexing(test_case, placement, dtype):

# pick a random valid indexer type
def ri(indices):
choice = _randint(0, 2)
choice = _randint(0, 3)
if choice == 0:
return _cpu_global_tensor(flow.LongTensor(indices)).to_global(
placement, broadcast_for_placement
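This one-line test fix only makes sense if `_randint`'s upper bound is exclusive (numpy-style) — an assumption, since the helper's definition is not shown here. Under that assumption, `_randint(0, 2)` could never return 2, leaving the third indexer branch dead:

```python
import numpy as np

def _randint(low, high):
    # hypothetical helper matching the assumed numpy-style semantics:
    # `high` is exclusive
    return np.random.randint(low, high)

# _randint(0, 2) -> {0, 1} only; _randint(0, 3) -> {0, 1, 2},
# so `choice == 2` (the third indexer type) is finally exercised
```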