1717#include < Python.h>
1818
1919#include < algorithm>
20+ #include < cstdint>
2021#include " paddle/fluid/eager/api/all.h"
2122#include " paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
2223#include " paddle/fluid/eager/utils.h"
3031#include " paddle/phi/kernels/funcs/common_infer_shape_functions.h"
3132#include " paddle/phi/kernels/funcs/slice_utils.h"
3233#include " paddle/phi/kernels/funcs/strided_slice.h"
34+ #include " paddle/utils/pybind.h"
3335#include " pybind11/numpy.h"
3436#include " pybind11/pybind11.h"
3537#include " pybind11/stl.h"
@@ -519,21 +521,31 @@ static void ParseIndex(const paddle::Tensor& tensor,
519521 estimated_dim++;
520522 }
521523 } else {
524+ *has_advanced_index = true ;
522525 if (slice_tensor.dtype () == phi::DataType::BOOL) {
523- PADDLE_ENFORCE_EQ (slice_tensor.shape ()[0 ],
524- dim_len,
525- common::errors::OutOfRange (
526- " The shape of boolean index %d did not match"
527- " indexed tensor %d along axis %d." ,
528- slice_tensor.shape ()[0 ],
529- dim_len,
530- current_dim));
526+ // bool tensor consumes (rank of index tensor) dimensions of input
527+ // tensor
528+ for (int i = 0 ; i < slice_tensor.shape ().size (); i++) {
529+ PADDLE_ENFORCE_EQ (slice_tensor.shape ()[i],
530+ dim_len,
531+ common::errors::OutOfRange (
532+ " The shape of boolean index %d did not match"
533+ " indexed tensor %d along axis %d." ,
534+ slice_tensor.shape ()[0 ],
535+ dim_len,
536+ current_dim));
537+ (*advanced_index_dim)[estimated_dim] = estimated_dim;
538+ estimated_dim++;
539+ current_dim++;
540+ dim_len = shape[current_dim];
541+ }
542+ } else {
543+ // int tensor consumes only one dimension of input tensor
544+ (*advanced_index_dim)[estimated_dim] = estimated_dim;
545+ estimated_dim++;
546+ current_dim++;
531547 }
532- *has_advanced_index = true ;
533548 advanced_index->push_back (std::move (slice_tensor));
534- (*advanced_index_dim)[estimated_dim] = estimated_dim;
535- estimated_dim++;
536- current_dim++;
537549 }
538550
539551 } else {
@@ -648,17 +660,14 @@ static paddle::Tensor dealWithAdvancedIndex(
648660 int * rank_of_new_dim,
649661 std::vector<int >* trans_dim,
650662 bool * out_is_view) {
663+ *rank_of_new_dim = 0 ;
651664 int p = 0 ;
652- bool int_tensor_only = true ;
653665 for (size_t i = 0 ; i < advanced_index_dim->size (); ++i) {
654666 auto index_dim = (*advanced_index_dim)[i];
655667 if (index_dim != -1 ) {
656- // size of advanced_index is same to number of non -1 element in
657- // advanced_index_dim
668+ // sum of each advanced_index_tensor's rank equals to number of non -1
669+ // element in advanced_index_dim
658670 auto index = (*advanced_index)[p++];
659- if (index.dtype () == phi::DataType::BOOL) {
660- int_tensor_only = false ;
661- }
662671
663672 if (index_dim == 0 ) {
664673 // case 1: advanced indices at axis 0, the new dim will be at first.
@@ -671,11 +680,23 @@ static paddle::Tensor dealWithAdvancedIndex(
671680 } else {
672681 *pos_of_new_dim = std::min (index_dim, *pos_of_new_dim);
673682 }
674- *rank_of_new_dim =
675- std::max (*rank_of_new_dim, static_cast <int >(index.shape ().size ()));
676683
677- trans_dim->push_back (index_dim);
678- transed_index->push_back (std::move (index));
684+ if (index.dtype () == phi::DataType::BOOL) {
685+ *rank_of_new_dim = std::max (*rank_of_new_dim, 1 );
686+ i--;
687+ for (int j = 0 ; j < index.shape ().size (); j++) {
688+ i++;
689+ index_dim = (*advanced_index_dim)[i];
690+ trans_dim->push_back (index_dim);
691+ }
692+ transed_index->push_back (std::move (index));
693+ } else {
694+ *rank_of_new_dim =
695+ std::max (*rank_of_new_dim, static_cast <int >(index.shape ().size ()));
696+
697+ trans_dim->push_back (index_dim);
698+ transed_index->push_back (std::move (index));
699+ }
679700 }
680701 }
681702
@@ -695,8 +716,7 @@ static paddle::Tensor dealWithAdvancedIndex(
695716 transed_tensor = tensor;
696717 } else {
697718 *out_is_view = true ;
698- if (FLAGS_use_stride_kernel && *pos_of_new_dim != 0 &&
699- (is_for_setitem || int_tensor_only)) {
719+ if (FLAGS_use_stride_kernel && *pos_of_new_dim != 0 ) {
700720 transed_tensor = tensor;
701721 } else {
702722 transed_tensor = transpose_ad_func (tensor, *trans_dim);
@@ -731,9 +751,10 @@ static std::vector<paddle::Tensor> PrepareIndices(
731751}
732752
733753static paddle::Tensor getValueForBoolTensor (const paddle::Tensor& tensor,
754+ const paddle::Tensor& self_tensor,
734755 const paddle::Tensor& bool_index,
735756 const int64_t slice_offset,
736- const bool is_combined_bool ) {
757+ const int64_t pos_of_new_dim ) {
737758 PADDLE_ENFORCE (bool_index.shape ().size () <= tensor.shape ().size (),
738759 common::errors::InvalidArgument (
739760 " The dims of bool index doesn't match indexed array, "
@@ -743,34 +764,52 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
743764 bool_index.shape ().size ()));
744765 auto tensor_shape = tensor.shape ();
745766 size_t i = 0 ;
746- while (i < bool_index.shape ().size ()) {
747- PADDLE_ENFORCE_EQ (
748- bool_index.shape ()[i],
749- tensor_shape[i],
750- common::errors::OutOfRange (
751- " The dimension of bool index doesn't match indexed array along "
752- " dimension %d, the target dimension is %d, but received %d" ,
753- i,
754- tensor_shape[i],
755- bool_index.shape ()[i]));
756- i++;
767+ if (FLAGS_use_stride_kernel) {
768+ while (i < bool_index.shape ().size ()) {
769+ PADDLE_ENFORCE_EQ (
770+ bool_index.shape ()[i],
771+ tensor_shape[i + pos_of_new_dim],
772+ common::errors::OutOfRange (
773+ " The dimension of bool index doesn't match indexed array along "
774+ " dimension %d, the target dimension is %d, but received %d" ,
775+ i,
776+ tensor_shape[i + pos_of_new_dim],
777+ bool_index.shape ()[i]));
778+ i++;
779+ }
780+ } else {
781+ while (i < bool_index.shape ().size ()) {
782+ PADDLE_ENFORCE_EQ (
783+ bool_index.shape ()[i],
784+ tensor_shape[i],
785+ common::errors::OutOfRange (
786+ " The dimension of bool index doesn't match indexed array along "
787+ " dimension %d, the target dimension is %d, but received %d" ,
788+ i,
789+ tensor_shape[i],
790+ bool_index.shape ()[i]));
791+ i++;
792+ }
757793 }
758794
759795 const phi::distributed::ProcessMesh* mesh = nullptr ;
760- if (InputsContainDistTensor (&mesh, tensor, bool_index)) {
761- ConvertAllInputsToDistTensor (mesh, tensor, bool_index);
796+ if (InputsContainDistTensor (&mesh, tensor, self_tensor, bool_index)) {
797+ ConvertAllInputsToDistTensor (mesh, tensor, self_tensor, bool_index);
762798 }
763799
764800 if (bool_index.shape ().size () == tensor_shape.size ()) {
765801 return masked_select_ad_func (tensor, bool_index);
766802 }
767803
768804 auto bool_2_idx = nonzero_ad_func (bool_index);
769- if (FLAGS_use_stride_kernel && !is_combined_bool ) {
805+ if (FLAGS_use_stride_kernel) {
770806 std::vector<paddle::Tensor> indices =
771807 PrepareIndices (tensor, bool_2_idx, bool_index);
808+ for (int i = 0 ; i < pos_of_new_dim; ++i) {
809+ indices.insert (indices.begin (), paddle::Tensor ());
810+ }
772811 while (indices.size () < static_cast <size_t >(tensor.dims ().size ())) {
773- indices.emplace_back ();
812+ indices.emplace_back (paddle::Tensor () );
774813 }
775814
776815 std::vector<paddle::Tensor> indices_int64;
@@ -784,7 +823,7 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
784823 AdvancedIndex ad = AdvancedIndex (tensor, indices_int64);
785824 const bool accumulate = false ;
786825
787- return index_elementwise_get_ad_func (tensor ,
826+ return index_elementwise_get_ad_func (self_tensor ,
788827 ad.indices ,
789828 ad.src_sizes ,
790829 ad.src_strides ,
@@ -1173,7 +1212,6 @@ static void ApplySetitem(const std::vector<int> trans_dim,
11731212static void ApplyGetitem (const int index_size,
11741213 const int pos_of_new_dim,
11751214 const int rank_of_new_dim,
1176- const bool is_combined_bool,
11771215 std::vector<paddle::Tensor>* transed_index,
11781216 paddle::Tensor* tensor,
11791217 paddle::Tensor* self_tensor,
@@ -1202,9 +1240,18 @@ static void ApplyGetitem(const int index_size,
12021240 if (transed_index->size () == 1 &&
12031241 (*transed_index)[0 ].dtype () == phi::DataType::BOOL) {
12041242 // get value for bool tensor
1205- int64_t slice_offset = 0 ;
1206- *out = getValueForBoolTensor (
1207- *transed_tensor, (*transed_index)[0 ], slice_offset, is_combined_bool);
1243+ const int64_t slice_offset =
1244+ reinterpret_cast <const char *>(transed_tensor->data ()) -
1245+ reinterpret_cast <const char *>(self_tensor->data ());
1246+ *out = getValueForBoolTensor (*transed_tensor,
1247+ (*self_tensor),
1248+ (*transed_index)[0 ],
1249+ slice_offset,
1250+ pos_of_new_dim);
1251+ if (!FLAGS_use_stride_kernel) {
1252+ handle_transpose (*out);
1253+ }
1254+ return ;
12081255 } else {
12091256 // get value for int tensor
12101257 ParseBoolAndBroadcastIndices (transed_index);
@@ -1216,14 +1263,15 @@ static void ApplyGetitem(const int index_size,
12161263 }
12171264 }
12181265
1219- if (FLAGS_use_stride_kernel && !is_combined_bool && ! has_empty_index) {
1266+ if (FLAGS_use_stride_kernel && !has_empty_index) {
12201267 const phi::distributed::ProcessMesh* mesh = nullptr ;
12211268 if (InputsContainDistTensor (
12221269 &mesh, *self_tensor, *transed_tensor, *transed_index)) {
12231270 ConvertAllInputsToDistTensor (
12241271 mesh, *self_tensor, *transed_tensor, *transed_index);
12251272 }
12261273
1274+ *transed_index = expandTensors (*transed_index);
12271275 *transed_index = expand_outplace (*transed_index);
12281276
12291277 std::vector<paddle::Tensor> transed_index_int64;
@@ -1278,7 +1326,6 @@ static void ApplyGetitem(const int index_size,
12781326 return ;
12791327 }
12801328 }
1281-
12821329 handle_transpose (*out);
12831330}
12841331
0 commit comments