Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat lazy tensor indexing #9334

Merged
merged 32 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
6d40284
feat(boxing): collective_boxing slice_boxing support 0size tensor
wyg1997 Oct 10, 2022
e1b02ae
test(Indexing): add lazy tensor basic indexing
wyg1997 Oct 11, 2022
8272026
add MaskTensor judgement
wyg1997 Oct 12, 2022
e636732
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 Oct 17, 2022
3abb999
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 Oct 28, 2022
e8e441b
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 Nov 1, 2022
dcb2065
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 Nov 1, 2022
a52eddb
format code
wyg1997 Nov 1, 2022
e03a98c
feat(TensorIndexing): support lazy advance getitem indexing
wyg1997 Nov 2, 2022
a4fab01
feat(Indexing): support lazy indexing for lazy_tensor and free_tensor
wyg1997 Nov 3, 2022
d15ed76
fix(Indexing): fix indexing test bug
wyg1997 Nov 3, 2022
4890857
test(Indexing): test all advance indexing
wyg1997 Nov 4, 2022
17fd5ff
test(GlobalIndexing): fix eager global indexing bug
wyg1997 Nov 4, 2022
01bbd25
test(Indexing): support combined indexing
wyg1997 Nov 4, 2022
8f95248
add last test cases
wyg1997 Nov 4, 2022
27c6961
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 Nov 4, 2022
1cb928e
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 Nov 7, 2022
3b0be93
fix merge bug
wyg1997 Nov 7, 2022
dd4b56a
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 Nov 10, 2022
acf65d0
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 Nov 14, 2022
a71de68
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 Nov 15, 2022
d2634c8
fix lazy mode guard
wyg1997 Nov 15, 2022
6ec3ec4
Merge remote-tracking branch 'origin/master' into feat-lazy_tensor_in…
wyg1997 Nov 15, 2022
7eff650
test(Indexing): refine set scalar value test
wyg1997 Nov 15, 2022
37c5adc
test(Indexing): enable all bool tensor index setitem
wyg1997 Nov 15, 2022
4affec4
decrease test time
wyg1997 Nov 15, 2022
fdf4f1f
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 Nov 21, 2022
baccecd
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 Nov 21, 2022
ed18408
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 Nov 22, 2022
60f2549
Merge branch 'master' into feat-lazy_tensor_indexing
wyg1997 Nov 23, 2022
40f30d5
refine 0size shape judgement
wyg1997 Nov 23, 2022
8c1785d
add comment
wyg1997 Nov 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 35 additions & 30 deletions oneflow/api/python/framework/tensor_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,8 +783,6 @@ static PyObject* PyTensorObject_to_local(PyObject* self, PyObject* unused, PyObj

int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {
HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self);
std::shared_ptr<Tensor> value_tensor;
CHECK_OR_THROW(functional::PyTensorIndexCheck(item))
<< Error::TypeError() << "tensor_setitem(): argument 'index' must be index, not "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(item)));
Expand All @@ -793,6 +791,7 @@ int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(value)));
const auto& index_item = functional::PyUnpackTensorIndex(item);

auto tensor = PyTensor_Unpack(self);
// NOTE: use masked_fill_(local,global) to avoid D2H in TensorSetItem if index is bool tensor
if (functional::PyScalarCheck(value) && index_item.size() == 1 && index_item[0].IsTensor()) {
const auto& index_tensor = index_item[0].tensor();
Expand All @@ -804,35 +803,41 @@ int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {
}
}

if (tensor->is_global()) {
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
auto ndsbp = ASSERT(tensor->nd_sbp());
std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),
ASSERT(MakeBroadcastSbpParallel()));
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::GlobalConstant(Shape({}), value_scalar, tensor->dtype(), placement, sbp));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_global())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a global tensor when self is global";
value_tensor =
ASSERT_PTR(functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false));
}
} else {
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::Constant(Shape({}), value_scalar, tensor->dtype(), ASSERT(tensor->device())));
std::shared_ptr<Tensor> value_tensor;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code here is long; this change only wraps it in an extra scope block (no logic change). [原文: 这里代码太长,只是加了一个作用域]

{
if (tensor->is_global()) {
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
auto ndsbp = ASSERT(tensor->nd_sbp());
std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),
ASSERT(MakeBroadcastSbpParallel()));
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::GlobalConstant(Shape({}), value_scalar, tensor->dtype(), placement, sbp));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_global())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a global tensor when self is global";
value_tensor = ASSERT_PTR(
functional::ToGlobal(value_tensor, placement, sbp, {}, true, /*copy=*/false));
}
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_local())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a local tensor when self is local";
Optional<Symbol<Device>> device = ASSERT(tensor->device());
value_tensor = ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false));
if (functional::PyScalarCheck(value)) {
// NOTE: initialize value_tensor in eager mode
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled=*/false);
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(functional::Constant(Shape({}), value_scalar, tensor->dtype(),
ASSERT(tensor->device())));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_local())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a local tensor when self is local";
Optional<Symbol<Device>> device = ASSERT(tensor->device());
value_tensor =
ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false));
}
}
}
ASSERT(functional::TensorSetItem(tensor, index_item, value_tensor));
Expand Down
2 changes: 2 additions & 0 deletions oneflow/api/python/functional/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ Shape InferArraySizes(PyObject* object) {
}

Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) {
// NOTE: convert data to indexing will ensure in eager mode
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
const DataType dtype = InferScalarType(object);
const auto& device = JUST(Device::New("cpu"));

Expand Down
9 changes: 8 additions & 1 deletion oneflow/core/functional/tensor_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int64_t CountSpecifiedDims(const TensorIndex& index) {
specified_ndims++;
} else if (index_item.IsTensor()) {
const auto& tensor = index_item.tensor();
if (tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8()) {
if (IsMaskTensor(tensor)) {
specified_ndims += tensor->ndim();
} else {
specified_ndims++;
Expand Down Expand Up @@ -247,6 +247,11 @@ Maybe<Tensor> PermuteBackForGlobalTensor(const std::shared_ptr<Tensor>& result,

} // namespace

// Returns true when `tensor`'s dtype marks it as a mask-style index
// (Bool, Int8, or UInt8), i.e. it selects elements rather than naming them.
bool IsMaskTensor(const std::shared_ptr<Tensor>& tensor) {
  const auto dtype = tensor->dtype();
  return dtype == DType::Bool() || dtype == DType::Int8() || dtype == DType::UInt8();
}

Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
std::vector<detail::Slice>* slice_indices,
TensorTuple* tensor_indices, std::vector<int64_t>* expand_dims,
Expand Down Expand Up @@ -558,6 +563,8 @@ Maybe<void> UnifyInputAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,
const auto tensor_index = tensor_indices[i];
if (tensor_index == nullptr) { continue; }
if (tensor_index->is_local()) {
// NOTE: LocalToGlobal should be called in eager mode
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
Comment on lines +566 to +567
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LocalToGlobal can only be called in eager mode. [原文: LocalToGlobal 只能在 eager 模式下调用]

tensor_indices[i] = JUST(ToGlobal(tensor_index, placement,
std::vector<Symbol<SbpParallel>>(n, broadcast_sbp),
grad_sbp_tuple, /*check_meta=*/false, /*copy=*/false));
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/functional/tensor_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class TensorIndex : public std::vector<detail::IndexItem> {
using std::vector<detail::IndexItem>::vector;
};

bool IsMaskTensor(const std::shared_ptr<Tensor>& tensor);

Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
std::vector<detail::Slice>* slice_indices,
TensorTuple* tensor_indices, std::vector<int64_t>* expand_dims,
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/kernel/collective_boxing_unpack_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ void CollectiveBoxingUnpackKernel::ForwardDataContent(KernelContext* ctx) const
const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf();
const int64_t num_ranks = unpack_conf.num_ranks();
const Shape logical_shape(unpack_conf.logical_shape());
// skip 0size tensor boxing
if (logical_shape.elem_cnt() == 0) { return; }
const bool need_transpose = !((unpack_conf.src_sbp_parallel().has_split_parallel()
&& unpack_conf.src_sbp_parallel().split_parallel().axis() == 0)
|| unpack_conf.src_sbp_parallel().has_broadcast_parallel()
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/kernel/slice_boxing_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ class SliceBoxingAddKernel final : public SliceBoxingKernel {

void SliceBoxingKernel::VirtualKernelInit(KernelContext* ctx) {
const SliceBoxingConf& conf = GetCustomizedBoxingConf();
if (std::accumulate(conf.out_shape().dim().begin(), conf.out_shape().dim().end(), 1,
std::multiplies<int64_t>())
== 0) {
return;
}
const TensorSliceView out_slice(conf.out_slice());
for (const TensorSliceViewProto& in_slice_proto : conf.in_slice()) {
const TensorSliceView in_slice(in_slice_proto);
Expand All @@ -82,6 +87,7 @@ const SliceBoxingConf& SliceBoxingCopyKernel::GetCustomizedBoxingConf() const {

void SliceBoxingCopyKernel::ForwardDataContent(KernelContext* ctx) const {
Blob* out = ctx->BnInOp2Blob("out");
if (out->shape_view().elem_cnt() == 0) { return; }
FOR_RANGE(int64_t, i, 0, this->op_attribute().input_bns().size()) {
const Blob* in_i = ctx->BnInOp2Blob(GenRepeatedBn("in", i));
this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), out, in_i);
Expand All @@ -94,6 +100,7 @@ const SliceBoxingConf& SliceBoxingAddKernel::GetCustomizedBoxingConf() const {

void SliceBoxingAddKernel::ForwardDataContent(KernelContext* ctx) const {
Blob* out = ctx->BnInOp2Blob("out");
if (out->shape_view().elem_cnt() == 0) { return; }
std::unique_ptr<ep::primitive::Add> primitive =
ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->stream()->device_type(),
out->data_type());
Expand Down
6 changes: 5 additions & 1 deletion oneflow/user/ops/slice_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) {
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> SliceUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  // NOTE(review): this must NOT delegate to InferLogicalTensorDesc. SliceUpdate
  // supports S + B -> S, so the shape checks performed during logical-shape
  // inference do not hold for per-rank (physical) shapes. The physical output
  // desc simply mirrors the "ref" input.
  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
  auto* y_desc = ctx->MutOutputTensorDesc("y", 0);
  y_desc->set_shape(ref_desc.shape());
  y_desc->set_is_dynamic(ref_desc.is_dynamic());
  return Maybe<void>::Ok();
}
Comment on lines 106 to 112
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous physical-tensor inference for SliceUpdate was wrong: since it supports S + B -> S, it cannot share an inference function with the logical-shape inference (the logical inference contains shape checks that physical tensor-shape inference does not need). [原文: 之前 SliceUpdate 的物理 Tensor 推导是错误的,它支持 S + B -> S,是不能和逻辑 shape 推导共用推导函数(逻辑推导函数中有一些 shape 的检查,物理 tensor shape 推导不需要)]

/*static*/ Maybe<void> SliceUpdateOp::InferDataType(user_op::InferContext* ctx) {
const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/test/tensor/test_global_tensor_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _test_advanced_indexing(test_case, placement, dtype):

# pick a random valid indexer type
def ri(indices):
choice = _randint(0, 2)
choice = _randint(0, 3)
if choice == 0:
return _cpu_global_tensor(flow.LongTensor(indices)).to_global(
placement, broadcast_for_placement
Expand Down
Loading