9 changes: 0 additions & 9 deletions paddle/fluid/pybind/eager_method.cc
@@ -1601,19 +1601,10 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
&trans_dim,
&out_is_view);

bool has_bool_index = false;
for (auto& index : transed_index) {
if (index.dtype() == phi::DataType::BOOL) {
has_bool_index = true;
}
}
const int index_size = PyTuple_GET_SIZE(index_ptr);
const bool is_combined_bool = has_bool_index && index_size > 1;

ApplyGetitem(index_size,
pos_of_new_dim,
rank_of_new_dim,
is_combined_bool,
&transed_index,
&tensor,
&self->tensor,
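Note: the binding no longer precomputes whether the index tuple combines a boolean mask with other indices; that distinction is now handled inside the slice utilities below. For reference, a minimal NumPy sketch of what a "combined" boolean index means (the array is illustrative, not the test fixture):

```python
import numpy as np

# A "combined" boolean index mixes a bool mask with other index types
# (slices, integers) in one subscript. The removed is_combined_bool flag
# detected exactly this situation.
x = np.arange(24).reshape(2, 3, 4)
mask = np.array([True, False])   # bool mask on axis 0
out = x[mask, :, -1]             # bool mask combined with a slice and an int
print(out)                       # [[ 3  7 11]], shape (1, 3)
```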
141 changes: 94 additions & 47 deletions paddle/fluid/pybind/slice_utils.h
@@ -17,6 +17,7 @@
#include <Python.h>

#include <algorithm>
#include <cstdint>
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/utils.h"
@@ -30,6 +31,7 @@
#include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/utils/pybind.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
@@ -519,21 +521,31 @@ static void ParseIndex(const paddle::Tensor& tensor,
estimated_dim++;
}
} else {
*has_advanced_index = true;
if (slice_tensor.dtype() == phi::DataType::BOOL) {
PADDLE_ENFORCE_EQ(slice_tensor.shape()[0],
dim_len,
common::errors::OutOfRange(
"The shape of boolean index %d did not match"
"indexed tensor %d along axis %d.",
slice_tensor.shape()[0],
dim_len,
current_dim));
// bool tensor consumes (rank of index tensor) dimensions of input
// tensor
for (int i = 0; i < slice_tensor.shape().size(); i++) {
PADDLE_ENFORCE_EQ(slice_tensor.shape()[i],
dim_len,
common::errors::OutOfRange(
"The shape of boolean index %d did not match"
"indexed tensor %d along axis %d.",
slice_tensor.shape()[0],
dim_len,
current_dim));
(*advanced_index_dim)[estimated_dim] = estimated_dim;
estimated_dim++;
current_dim++;
dim_len = shape[current_dim];
}
} else {
// int tensor consumes only one dimension of input tensor
(*advanced_index_dim)[estimated_dim] = estimated_dim;
estimated_dim++;
current_dim++;
}
*has_advanced_index = true;
advanced_index->push_back(std::move(slice_tensor));
(*advanced_index_dim)[estimated_dim] = estimated_dim;
estimated_dim++;
current_dim++;
}

} else {
@@ -648,17 +660,14 @@ static paddle::Tensor dealWithAdvancedIndex(
int* rank_of_new_dim,
std::vector<int>* trans_dim,
bool* out_is_view) {
*rank_of_new_dim = 0;
int p = 0;
bool int_tensor_only = true;
for (size_t i = 0; i < advanced_index_dim->size(); ++i) {
auto index_dim = (*advanced_index_dim)[i];
if (index_dim != -1) {
// size of advanced_index is same to number of non -1 element in
// advanced_index_dim
// the sum of the ranks of the advanced index tensors equals the number
// of non -1 elements in advanced_index_dim
auto index = (*advanced_index)[p++];
if (index.dtype() == phi::DataType::BOOL) {
int_tensor_only = false;
}

if (index_dim == 0) {
// case 1: advanced indices at axis 0, the new dim will be at first.
@@ -671,11 +680,23 @@
} else {
*pos_of_new_dim = std::min(index_dim, *pos_of_new_dim);
}
*rank_of_new_dim =
std::max(*rank_of_new_dim, static_cast<int>(index.shape().size()));

trans_dim->push_back(index_dim);
transed_index->push_back(std::move(index));
if (index.dtype() == phi::DataType::BOOL) {
*rank_of_new_dim = std::max(*rank_of_new_dim, 1);
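// a rank-k bool index owns k consecutive slots of advanced_index_dim;
// step i back once so the loop below pushes all k trans_dim entries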
i--;
for (int j = 0; j < index.shape().size(); j++) {
i++;
index_dim = (*advanced_index_dim)[i];
trans_dim->push_back(index_dim);
}
transed_index->push_back(std::move(index));
} else {
*rank_of_new_dim =
std::max(*rank_of_new_dim, static_cast<int>(index.shape().size()));

trans_dim->push_back(index_dim);
transed_index->push_back(std::move(index));
}
}
}

@@ -695,8 +716,7 @@ static paddle::Tensor dealWithAdvancedIndex(
transed_tensor = tensor;
} else {
*out_is_view = true;
if (FLAGS_use_stride_kernel && *pos_of_new_dim != 0 &&
(is_for_setitem || int_tensor_only)) {
if (FLAGS_use_stride_kernel && *pos_of_new_dim != 0) {
transed_tensor = tensor;
} else {
transed_tensor = transpose_ad_func(tensor, *trans_dim);
@@ -731,9 +751,10 @@ static std::vector<paddle::Tensor> PrepareIndices(
}

static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
const paddle::Tensor& self_tensor,
const paddle::Tensor& bool_index,
const int64_t slice_offset,
const bool is_combined_bool) {
const int64_t pos_of_new_dim) {
PADDLE_ENFORCE(bool_index.shape().size() <= tensor.shape().size(),
common::errors::InvalidArgument(
"The dims of bool index doesn't match indexed array, "
@@ -743,34 +764,52 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
bool_index.shape().size()));
auto tensor_shape = tensor.shape();
size_t i = 0;
while (i < bool_index.shape().size()) {
PADDLE_ENFORCE_EQ(
bool_index.shape()[i],
tensor_shape[i],
common::errors::OutOfRange(
"The dimension of bool index doesn't match indexed array along "
"dimension %d, the target dimension is %d, but received %d",
i,
tensor_shape[i],
bool_index.shape()[i]));
i++;
if (FLAGS_use_stride_kernel) {
while (i < bool_index.shape().size()) {
PADDLE_ENFORCE_EQ(
bool_index.shape()[i],
tensor_shape[i + pos_of_new_dim],
common::errors::OutOfRange(
"The dimension of bool index doesn't match indexed array along "
"dimension %d, the target dimension is %d, but received %d",
i,
tensor_shape[i + pos_of_new_dim],
bool_index.shape()[i]));
i++;
}
} else {
while (i < bool_index.shape().size()) {
PADDLE_ENFORCE_EQ(
bool_index.shape()[i],
tensor_shape[i],
common::errors::OutOfRange(
"The dimension of bool index doesn't match indexed array along "
"dimension %d, the target dimension is %d, but received %d",
i,
tensor_shape[i],
bool_index.shape()[i]));
i++;
}
}

const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, tensor, bool_index)) {
ConvertAllInputsToDistTensor(mesh, tensor, bool_index);
if (InputsContainDistTensor(&mesh, tensor, self_tensor, bool_index)) {
ConvertAllInputsToDistTensor(mesh, tensor, self_tensor, bool_index);
}

if (bool_index.shape().size() == tensor_shape.size()) {
return masked_select_ad_func(tensor, bool_index);
}

auto bool_2_idx = nonzero_ad_func(bool_index);
if (FLAGS_use_stride_kernel && !is_combined_bool) {
if (FLAGS_use_stride_kernel) {
std::vector<paddle::Tensor> indices =
PrepareIndices(tensor, bool_2_idx, bool_index);
for (int i = 0; i < pos_of_new_dim; ++i) {
indices.insert(indices.begin(), paddle::Tensor());
}
while (indices.size() < static_cast<size_t>(tensor.dims().size())) {
indices.emplace_back();
indices.emplace_back(paddle::Tensor());
}

std::vector<paddle::Tensor> indices_int64;
@@ -784,7 +823,7 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
AdvancedIndex ad = AdvancedIndex(tensor, indices_int64);
const bool accumulate = false;

return index_elementwise_get_ad_func(tensor,
return index_elementwise_get_ad_func(self_tensor,
ad.indices,
ad.src_sizes,
ad.src_strides,
@@ -1172,7 +1211,6 @@ static void ApplySetitem(const std::vector<int> trans_dim,
static void ApplyGetitem(const int index_size,
const int pos_of_new_dim,
const int rank_of_new_dim,
const bool is_combined_bool,
std::vector<paddle::Tensor>* transed_index,
paddle::Tensor* tensor,
paddle::Tensor* self_tensor,
@@ -1201,9 +1239,18 @@ static void ApplyGetitem(const int index_size,
if (transed_index->size() == 1 &&
(*transed_index)[0].dtype() == phi::DataType::BOOL) {
// get value for bool tensor
int64_t slice_offset = 0;
*out = getValueForBoolTensor(
*transed_tensor, (*transed_index)[0], slice_offset, is_combined_bool);
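// byte offset of the (possibly strided) view's data from the base tensor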
const int64_t slice_offset =
reinterpret_cast<const char*>(transed_tensor->data()) -
reinterpret_cast<const char*>(self_tensor->data());
*out = getValueForBoolTensor(*transed_tensor,
(*self_tensor),
(*transed_index)[0],
slice_offset,
pos_of_new_dim);
if (!FLAGS_use_stride_kernel) {
handle_transpose(*out);
}
return;
} else {
// get value for int tensor
ParseBoolAndBroadcastIndices(transed_index);
@@ -1215,14 +1262,15 @@
}
}

if (FLAGS_use_stride_kernel && !is_combined_bool && !has_empty_index) {
if (FLAGS_use_stride_kernel && !has_empty_index) {
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(
&mesh, *self_tensor, *transed_tensor, *transed_index)) {
ConvertAllInputsToDistTensor(
mesh, *self_tensor, *transed_tensor, *transed_index);
}

*transed_index = expandTensors(*transed_index);
*transed_index = expand_outplace(*transed_index);

std::vector<paddle::Tensor> transed_index_int64;
@@ -1277,7 +1325,6 @@
return;
}
}

handle_transpose(*out);
}

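The reworked bool-index path above rests on two facts: a rank-k boolean mask consumes k dimensions of the indexed tensor (the new ParseIndex loop), and nonzero() turns that mask into k integer index tensors that gather the same elements (PrepareIndices feeding index_elementwise_get, with empty indices padding the first pos_of_new_dim axes). A minimal NumPy sketch of the same equivalence, with an illustrative array:

```python
import numpy as np

# A rank-2 boolean mask applied at axis 1 consumes two dimensions of x;
# decomposing it with nonzero() reproduces the same gather with integer
# indices, one index tensor per mask axis.
x = np.arange(24).reshape(2, 3, 4)
mask = np.arange(12).reshape(3, 4) % 5 == 0   # rank-2 mask, shape (3, 4)

rows, cols = np.nonzero(mask)                 # one int index per mask axis
assert np.array_equal(x[:, mask], x[:, rows, cols])
print(x[:, mask].shape)                       # (2, 3): mask.sum() == 3
```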
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc
@@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad,
bool,
int64_t,
int16_t,
int8_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
4 changes: 4 additions & 0 deletions paddle/phi/kernels/strided_slice_grad_kernel.cc
@@ -49,6 +49,8 @@ PD_REGISTER_KERNEL(strided_slice_grad,
phi::StridedSliceGradKernel,
bool,
int,
int8_t,
int16_t,
int64_t,
float,
double,
@@ -62,6 +64,8 @@ PD_REGISTER_KERNEL(strided_slice_grad,
phi::StridedSliceGradKernel,
bool,
int,
int8_t,
int16_t,
int64_t,
float,
double,
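Both registration files gain int8_t and int16_t entries, so the strided-slice backward kernels cover the same integer dtypes as the forward ones. A small illustrative example (no real gradient flows for integer dtypes; the point is that dtype dispatch no longer misses a kernel on these paths):

```python
import paddle

# With int8/int16 now registered for the strided-slice grad kernels,
# integer tensors can take the same strided-slice code paths as float
# ones without a missing-kernel error.
x = paddle.arange(10).astype('int8')
print(x[::2].numpy())   # [0 2 4 6 8] via a step-2 strided slice
```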
10 changes: 10 additions & 0 deletions test/indexing/test_getitem_appendix.py
@@ -230,6 +230,16 @@ def test_combined(self):
# case 6:
# [[[4 , 5 ],[10, 11],[16, 17],[22, 23]]]
self.accuracy_check(x[[True, False], :, -1], y[[True, False], :, -1])
# case 7:
# [[0, 3, 4, 5], [24, 26, 28, 29]]
index_np = np.array([[True, False], [False, True], [True, True]])
index_paddle = paddle.to_tensor(index_np)
self.accuracy_check(x[:, 0, index_np], y[:, 0, index_paddle])
# case 8:
# [[[[0, 1]], [[2, 3]], [[24, 25]], [[26, 27]]]]
index_np = np.array([[0], [1]])
index_paddle = paddle.to_tensor(index_np)
self.accuracy_check(x[:, 0, index_np], y[:, 0, index_paddle])


class Test0DTensorIndexing(unittest.TestCase):
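For readers without the test fixture at hand, a self-contained NumPy sketch of the pattern cases 7 and 8 exercise (rank-2 boolean and integer indices at a middle axis; the array below is illustrative only, not the test's x):

```python
import numpy as np

# Case 7's pattern: a rank-2 bool mask at a middle axis consumes two
# dimensions. Case 8's pattern: a rank-2 int index consumes one dimension
# and injects its own shape into the result.
x = np.arange(24).reshape(2, 2, 3, 2)

bool_idx = np.array([[True, False], [False, True], [True, True]])
print(x[:, 0, bool_idx].shape)   # (2, 4): two axes consumed, mask.sum() == 4

int_idx = np.array([[0], [1]])
print(x[:, 0, int_idx].shape)    # (2, 2, 1, 2): one axis consumed
```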