Commit 3beb3b3

[Slice] Dispatch bool-only combined case to non-zero and elementwise_get (#74320)
* update slice-check
* pass
* Refine and Add ut
* Rerun CI
* Fix multi-dim bool tensor index
* Fix strided=0 case

Co-authored-by: zhanghonggeng <zhanghonggeng@baidu.com>
1 parent 7db3dff commit 3beb3b3
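For context, the "bool-only combined case" in the title is a subscript that mixes a boolean tensor index with basic indices (slices or integers), as exercised by the new test cases below. A minimal sketch of the user-facing behavior, with illustrative shapes and names that are not taken from the test suite:

    import numpy as np
    import paddle

    x = paddle.arange(30).reshape([2, 3, 5])
    mask = paddle.to_tensor(np.array([[True, False, True],
                                      [False, True, True]]))  # rank-2 bool index
    # Combined case: a bool index plus a basic index. With this commit the
    # bool index is lowered via nonzero() and gathered with
    # index_elementwise_get on the strided path, instead of being excluded
    # from that path.
    out = x[mask, 0]  # mask consumes dims 0 and 1; 0 indexes the last dim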

5 files changed: +109 additions, -56 deletions


paddle/fluid/pybind/eager_method.cc

Lines changed: 0 additions & 9 deletions
@@ -1623,19 +1623,10 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                               &trans_dim,
                               &out_is_view);
 
-  bool has_bool_index = false;
-  for (auto& index : transed_index) {
-    if (index.dtype() == phi::DataType::BOOL) {
-      has_bool_index = true;
-    }
-  }
   const int index_size = PyTuple_GET_SIZE(index_ptr);
-  const bool is_combined_bool = has_bool_index && index_size > 1;
-
   ApplyGetitem(index_size,
                pos_of_new_dim,
                rank_of_new_dim,
-               is_combined_bool,
                &transed_index,
                &tensor,
                &self->tensor,

paddle/fluid/pybind/slice_utils.h

Lines changed: 94 additions & 47 deletions
@@ -17,6 +17,7 @@
 #include <Python.h>
 
 #include <algorithm>
+#include <cstdint>
 #include "paddle/fluid/eager/api/all.h"
 #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
 #include "paddle/fluid/eager/utils.h"
@@ -30,6 +31,7 @@
 #include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"
 #include "paddle/phi/kernels/funcs/strided_slice.h"
+#include "paddle/utils/pybind.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
@@ -519,21 +521,31 @@ static void ParseIndex(const paddle::Tensor& tensor,
         estimated_dim++;
       }
     } else {
+      *has_advanced_index = true;
       if (slice_tensor.dtype() == phi::DataType::BOOL) {
-        PADDLE_ENFORCE_EQ(slice_tensor.shape()[0],
-                          dim_len,
-                          common::errors::OutOfRange(
-                              "The shape of boolean index %d did not match"
-                              "indexed tensor %d along axis %d.",
-                              slice_tensor.shape()[0],
-                              dim_len,
-                              current_dim));
+        // bool tensor consumes (rank of index tensor) dimensions of input
+        // tensor
+        for (int i = 0; i < slice_tensor.shape().size(); i++) {
+          PADDLE_ENFORCE_EQ(slice_tensor.shape()[i],
+                            dim_len,
+                            common::errors::OutOfRange(
+                                "The shape of boolean index %d did not match"
+                                "indexed tensor %d along axis %d.",
+                                slice_tensor.shape()[0],
+                                dim_len,
+                                current_dim));
+          (*advanced_index_dim)[estimated_dim] = estimated_dim;
+          estimated_dim++;
+          current_dim++;
+          dim_len = shape[current_dim];
+        }
+      } else {
+        // int tensor consumes only one dimension of input tensor
+        (*advanced_index_dim)[estimated_dim] = estimated_dim;
+        estimated_dim++;
+        current_dim++;
       }
-      *has_advanced_index = true;
       advanced_index->push_back(std::move(slice_tensor));
-      (*advanced_index_dim)[estimated_dim] = estimated_dim;
-      estimated_dim++;
-      current_dim++;
     }
 
   } else {
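The new loop encodes the rule (shared with NumPy) that a rank-k boolean index occupies k consecutive dimensions of the indexed tensor, each of which must match in length, while an integer index tensor occupies exactly one. A small NumPy illustration of that dimension accounting:

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    mask = np.ones((2, 3), dtype=bool)  # rank-2 bool index over dims 0 and 1
    print(x[mask].shape)                # (6, 4): two dims consumed
    idx = np.array([0, 1])              # rank-1 int index
    print(x[idx].shape)                 # (2, 3, 4): only dim 0 consumed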
@@ -648,17 +660,14 @@ static paddle::Tensor dealWithAdvancedIndex(
     int* rank_of_new_dim,
     std::vector<int>* trans_dim,
     bool* out_is_view) {
+  *rank_of_new_dim = 0;
   int p = 0;
-  bool int_tensor_only = true;
   for (size_t i = 0; i < advanced_index_dim->size(); ++i) {
     auto index_dim = (*advanced_index_dim)[i];
     if (index_dim != -1) {
-      // size of advanced_index is same to number of non -1 element in
-      // advanced_index_dim
+      // sum of each advanced_index_tensor's rank equals to number of non -1
+      // element in advanced_index_dim
       auto index = (*advanced_index)[p++];
-      if (index.dtype() == phi::DataType::BOOL) {
-        int_tensor_only = false;
-      }
 
       if (index_dim == 0) {
         // case 1: advanced indices at axis 0, the new dim will be at first.
@@ -671,11 +680,23 @@ static paddle::Tensor dealWithAdvancedIndex(
       } else {
         *pos_of_new_dim = std::min(index_dim, *pos_of_new_dim);
       }
-      *rank_of_new_dim =
-          std::max(*rank_of_new_dim, static_cast<int>(index.shape().size()));
 
-      trans_dim->push_back(index_dim);
-      transed_index->push_back(std::move(index));
+      if (index.dtype() == phi::DataType::BOOL) {
+        *rank_of_new_dim = std::max(*rank_of_new_dim, 1);
+        i--;
+        for (int j = 0; j < index.shape().size(); j++) {
+          i++;
+          index_dim = (*advanced_index_dim)[i];
+          trans_dim->push_back(index_dim);
+        }
+        transed_index->push_back(std::move(index));
+      } else {
+        *rank_of_new_dim =
+            std::max(*rank_of_new_dim, static_cast<int>(index.shape().size()));
+
+        trans_dim->push_back(index_dim);
+        transed_index->push_back(std::move(index));
+      }
     }
   }
 
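Note the asymmetry the branch above captures: however many dimensions a boolean index spans, it contributes a single result dimension (one entry per True element), which is why rank_of_new_dim is clamped to at least 1 rather than taking the index tensor's rank, while trans_dim still records every input dimension the mask covers. In NumPy terms:

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    mask = np.array([[True, False, True],
                     [False, True, False]])  # rank-2 bool index, 3 True entries
    print(x[mask].shape)                     # (3, 4): two input dims -> ONE new dim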

@@ -695,8 +716,7 @@ static paddle::Tensor dealWithAdvancedIndex(
     transed_tensor = tensor;
   } else {
     *out_is_view = true;
-    if (FLAGS_use_stride_kernel && *pos_of_new_dim != 0 &&
-        (is_for_setitem || int_tensor_only)) {
+    if (FLAGS_use_stride_kernel && *pos_of_new_dim != 0) {
       transed_tensor = tensor;
     } else {
       transed_tensor = transpose_ad_func(tensor, *trans_dim);
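With is_combined_bool gone, the strided fast path keys only on pos_of_new_dim: when the advanced indices do not start at axis 0, the result dimension stays in place and the transpose can be skipped. Where that new dimension lands, sketched in NumPy:

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    idx = np.array([0, 1])
    print(x[idx].shape)     # (2, 3, 4): advanced index at axis 0, new dim first
    print(x[:, idx].shape)  # (2, 2, 4): pos_of_new_dim == 1, leading axis kept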
@@ -731,9 +751,10 @@ static std::vector<paddle::Tensor> PrepareIndices(
 }
 
 static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
+                                            const paddle::Tensor& self_tensor,
                                             const paddle::Tensor& bool_index,
                                             const int64_t slice_offset,
-                                            const bool is_combined_bool) {
+                                            const int64_t pos_of_new_dim) {
   PADDLE_ENFORCE(bool_index.shape().size() <= tensor.shape().size(),
                  common::errors::InvalidArgument(
                      "The dims of bool index doesn't match indexed array, "
@@ -743,34 +764,52 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
                      bool_index.shape().size()));
   auto tensor_shape = tensor.shape();
   size_t i = 0;
-  while (i < bool_index.shape().size()) {
-    PADDLE_ENFORCE_EQ(
-        bool_index.shape()[i],
-        tensor_shape[i],
-        common::errors::OutOfRange(
-            "The dimension of bool index doesn't match indexed array along "
-            "dimension %d, the target dimension is %d, but received %d",
-            i,
-            tensor_shape[i],
-            bool_index.shape()[i]));
-    i++;
+  if (FLAGS_use_stride_kernel) {
+    while (i < bool_index.shape().size()) {
+      PADDLE_ENFORCE_EQ(
+          bool_index.shape()[i],
+          tensor_shape[i + pos_of_new_dim],
+          common::errors::OutOfRange(
+              "The dimension of bool index doesn't match indexed array along "
+              "dimension %d, the target dimension is %d, but received %d",
+              i,
+              tensor_shape[i + pos_of_new_dim],
+              bool_index.shape()[i]));
+      i++;
+    }
+  } else {
+    while (i < bool_index.shape().size()) {
+      PADDLE_ENFORCE_EQ(
+          bool_index.shape()[i],
+          tensor_shape[i],
+          common::errors::OutOfRange(
+              "The dimension of bool index doesn't match indexed array along "
+              "dimension %d, the target dimension is %d, but received %d",
+              i,
+              tensor_shape[i],
+              bool_index.shape()[i]));
+      i++;
+    }
   }
 
   const phi::distributed::ProcessMesh* mesh = nullptr;
-  if (InputsContainDistTensor(&mesh, tensor, bool_index)) {
-    ConvertAllInputsToDistTensor(mesh, tensor, bool_index);
+  if (InputsContainDistTensor(&mesh, tensor, self_tensor, bool_index)) {
+    ConvertAllInputsToDistTensor(mesh, tensor, self_tensor, bool_index);
   }
 
   if (bool_index.shape().size() == tensor_shape.size()) {
     return masked_select_ad_func(tensor, bool_index);
   }
 
   auto bool_2_idx = nonzero_ad_func(bool_index);
-  if (FLAGS_use_stride_kernel && !is_combined_bool) {
+  if (FLAGS_use_stride_kernel) {
     std::vector<paddle::Tensor> indices =
         PrepareIndices(tensor, bool_2_idx, bool_index);
+    for (int i = 0; i < pos_of_new_dim; ++i) {
+      indices.insert(indices.begin(), paddle::Tensor());
+    }
     while (indices.size() < static_cast<size_t>(tensor.dims().size())) {
-      indices.emplace_back();
+      indices.emplace_back(paddle::Tensor());
     }
 
     std::vector<paddle::Tensor> indices_int64;
@@ -784,7 +823,7 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
     AdvancedIndex ad = AdvancedIndex(tensor, indices_int64);
     const bool accumulate = false;
 
-    return index_elementwise_get_ad_func(tensor,
+    return index_elementwise_get_ad_func(self_tensor,
                                          ad.indices,
                                          ad.src_sizes,
                                          ad.src_strides,
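The strided branch relies on the identity that boolean masking equals integer gathering over the mask's nonzero coordinates, with empty placeholder indices padding the axes the mask does not cover. The core equivalence, in NumPy:

    import numpy as np

    x = np.arange(12).reshape(3, 4)
    mask = np.array([True, False, True])
    rows = mask.nonzero()[0]           # lower the bool mask to integer indices
    assert (x[mask] == x[rows]).all()  # masking == elementwise gather by nonzero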
@@ -1173,7 +1212,6 @@ static void ApplySetitem(const std::vector<int> trans_dim,
 static void ApplyGetitem(const int index_size,
                          const int pos_of_new_dim,
                          const int rank_of_new_dim,
-                         const bool is_combined_bool,
                          std::vector<paddle::Tensor>* transed_index,
                          paddle::Tensor* tensor,
                          paddle::Tensor* self_tensor,
@@ -1202,9 +1240,18 @@ static void ApplyGetitem(const int index_size,
   if (transed_index->size() == 1 &&
       (*transed_index)[0].dtype() == phi::DataType::BOOL) {
     // get value for bool tensor
-    int64_t slice_offset = 0;
-    *out = getValueForBoolTensor(
-        *transed_tensor, (*transed_index)[0], slice_offset, is_combined_bool);
+    const int64_t slice_offset =
+        reinterpret_cast<const char*>(transed_tensor->data()) -
+        reinterpret_cast<const char*>(self_tensor->data());
+    *out = getValueForBoolTensor(*transed_tensor,
+                                 (*self_tensor),
+                                 (*transed_index)[0],
+                                 slice_offset,
+                                 pos_of_new_dim);
+    if (!FLAGS_use_stride_kernel) {
+      handle_transpose(*out);
+    }
+    return;
   } else {
     // get value for int tensor
     ParseBoolAndBroadcastIndices(transed_index);
@@ -1216,14 +1263,15 @@ static void ApplyGetitem(const int index_size,
     }
   }
 
-  if (FLAGS_use_stride_kernel && !is_combined_bool && !has_empty_index) {
+  if (FLAGS_use_stride_kernel && !has_empty_index) {
     const phi::distributed::ProcessMesh* mesh = nullptr;
     if (InputsContainDistTensor(
             &mesh, *self_tensor, *transed_tensor, *transed_index)) {
       ConvertAllInputsToDistTensor(
           mesh, *self_tensor, *transed_tensor, *transed_index);
     }
 
+    *transed_index = expandTensors(*transed_index);
     *transed_index = expand_outplace(*transed_index);
 
     std::vector<paddle::Tensor> transed_index_int64;
@@ -1278,7 +1326,6 @@ static void ApplyGetitem(const int index_size,
       return;
     }
   }
-
   handle_transpose(*out);
 }
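The slice_offset computed above is the byte distance from self_tensor's base buffer to the start of the (possibly pre-sliced) view being indexed, which the elementwise-get kernel needs when reading from the base tensor. The same quantity for a NumPy view, as an analogy:

    import numpy as np

    base = np.arange(12, dtype=np.int32).reshape(3, 4)
    view = base[1:]                               # strided view into base
    offset = view.ctypes.data - base.ctypes.data  # byte offset of the view
    print(offset)                                 # 16: one skipped row of 4 int32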

paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad,
                    bool,
                    int64_t,
                    int16_t,
+                   int8_t,
                    int,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,

paddle/phi/kernels/strided_slice_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -49,6 +49,8 @@ PD_REGISTER_KERNEL(strided_slice_grad,
                    phi::StridedSliceGradKernel,
                    bool,
                    int,
+                   int8_t,
+                   int16_t,
                    int64_t,
                    float,
                    double,
@@ -62,6 +64,8 @@ PD_REGISTER_KERNEL(strided_slice_grad,
                    phi::StridedSliceGradKernel,
                    bool,
                    int,
+                   int8_t,
+                   int16_t,
                    int64_t,
                    float,
                    double,

test/indexing/test_getitem_appendix.py

Lines changed: 10 additions & 0 deletions
@@ -230,6 +230,16 @@ def test_combined(self):
         # case 6:
         # [[[4 , 5 ],[10, 11],[16, 17],[22, 23]]]
         self.accuracy_check(x[[True, False], :, -1], y[[True, False], :, -1])
+        # case 7:
+        # [[0, 3, 4, 5], [24, 26, 28, 29]]
+        index_np = np.array([[True, False], [False, True], [True, True]])
+        index_paddle = paddle.to_tensor(index_np)
+        self.accuracy_check(x[:, 0, index_np], y[:, 0, index_paddle])
+        # case 8:
+        # [[[[0, 1]], [[2, 3]], [[24, 25]], [[26, 27]]]]
+        index_np = np.array([[0], [1]])
+        index_paddle = paddle.to_tensor(index_np)
+        self.accuracy_check(x[:, 0, index_np], y[:, 0, index_paddle])
 
 
 class Test0DTensorIndexing(unittest.TestCase):
