Optimize advanced setting by removing the last set_value (#60771)
* pure advanced-indexing setitem no longer writes back with set_value

* fix multiple outputs in the tensor_array PIR path

* apply the optimization only in dynamic mode

* add a pure advanced-indexing setitem case to fix coverage
zoooo0820 authored Jan 18, 2024
1 parent 4760488 commit 1251a32
Showing 4 changed files with 95 additions and 48 deletions.
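
In user-facing terms, the commit changes how a __setitem__ with only advanced (integer-array) indices is lowered in dygraph: the in-place index_put_ now writes straight into the target tensor, and the trailing transpose + set_value round-trip is traced only when basic indexing produced an intermediate result. A minimal sketch of the affected pattern, mirroring the new test added below (assumes a Paddle build containing this patch):

import paddle

x = paddle.zeros([3, 4, 5, 6])
# Pure advanced indexing: after this change the assignment is done by the
# inplace index_put_ alone; no trailing set_value op is traced.
x[[0, 1], [1, 2], [1]] = 10.0

The assigned values are identical either way; the optimization only removes redundant ops from the traced program.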
47 changes: 26 additions & 21 deletions paddle/fluid/pybind/eager_method.cc
@@ -1369,6 +1369,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                     &use_strided_slice);
 
   // step2: Dealing with basic indexing
+  bool out_is_view = false;
   auto out = getTensorWithBasicIndexing(tensor,
                                         &slice_axes,
                                         &slice_starts,
@@ -1377,7 +1378,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                                         &decrease_axis,
                                         &none_axes,
                                         &infer_flags,
-                                        &use_strided_slice);
+                                        &use_strided_slice,
+                                        &out_is_view);
 
   if (!has_advanced_index) {
     return ToPyObject(out);
@@ -1396,7 +1398,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                             &trans_back_dim,
                             &pos_of_new_dim,
                             &rank_of_new_dim,
-                            &trans_dim);
+                            &trans_dim,
+                            &out_is_view);
 
   if (transed_index.size() == 1 &&
       transed_index[0].dtype() == phi::DataType::BOOL) {
@@ -1691,6 +1694,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
   // 3. assign values to the sliced result by index_put OP;
   // 4. transpose back and assign the result to original tensor by set_value
   // OP.
+  bool out_is_view = false;
  paddle::Tensor sub_tensor = getTensorWithBasicIndexing(tensor,
                                                          &slice_axes,
                                                          &slice_starts,
@@ -1699,7 +1703,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                                                          &decrease_axis,
                                                          &none_axes,
                                                          &infer_flags,
-                                                         &use_strided_slice);
+                                                         &use_strided_slice,
+                                                         &out_is_view);
 
   std::vector<paddle::Tensor> transed_index;
   std::vector<int> trans_back_dim, trans_dim;
@@ -1715,7 +1720,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                             &trans_back_dim,
                             &pos_of_new_dim,
                             &rank_of_new_dim,
-                            &trans_dim);
+                            &trans_dim,
+                            &out_is_view);
 
   // Release gil and do tracing
   py::gil_scoped_release release;
@@ -1742,10 +1748,6 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
     value_tensor = transpose_ad_func(value_tensor, trans_dim);
   }
 
-  // TODO(zoooo0820) 1.Using inplace version index_put
-  // 2.Remove following code after backward bug fixed.
-  transed_sub_tensor = assign_ad_func(transed_sub_tensor);
-
   const phi::distributed::ProcessMesh* mesh = nullptr;
   if (InputsContainDistTensor(
           &mesh, self->tensor, transed_sub_tensor, value_tensor)) {
@@ -1754,19 +1756,22 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
   }
 
   transed_sub_tensor =
-      index_put_ad_func(transed_sub_tensor, transed_index, value_tensor);
-
-  paddle::Tensor transback_sub_tensor =
-      transpose_ad_func(transed_sub_tensor, trans_back_dim);
-
-  self->tensor = set_value_with_tensor__ad_func(self->tensor,
-                                                transback_sub_tensor,
-                                                slice_starts,
-                                                slice_ends,
-                                                slice_strides,
-                                                slice_axes,
-                                                decrease_axis,
-                                                none_axes);
+      index_put__ad_func(transed_sub_tensor, transed_index, value_tensor);
+
+  // TODO(zoooo0820) Remove following code after backward bug fixed.
+  if (out_is_view) {
+    paddle::Tensor transback_sub_tensor =
+        transpose_ad_func(transed_sub_tensor, trans_back_dim);
+
+    self->tensor = set_value_with_tensor__ad_func(self->tensor,
+                                                  transback_sub_tensor,
+                                                  slice_starts,
+                                                  slice_ends,
+                                                  slice_strides,
+                                                  slice_axes,
+                                                  decrease_axis,
+                                                  none_axes);
+  }
   if (PyCheckTensor(value_obj)) {
     // pass the stop_gradient from value to tensor.
     // pass stop gradient should be done after CheckInplace in
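
The new out_is_view flag is what gates the write-back: it is set whenever basic indexing actually sliced the tensor, None axes inserted dimensions, or the advanced indices forced a transpose. A small dygraph illustration of the two paths (a behavioral sketch inferred from the flag logic in this diff, assuming a patched build):

import paddle

# No slice, no None axes, no transpose: out_is_view stays false, so the
# set_value_with_tensor_ write-back above is skipped.
x = paddle.zeros([3, 4])
x[[0, 2]] = 1.0

# A leading basic slice marks the sub-tensor as an intermediate "view", so
# the patched code still transposes back and writes into y with set_value.
y = paddle.zeros([3, 4])
y[1:, [0, 3]] = 2.0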
9 changes: 7 additions & 2 deletions paddle/fluid/pybind/slice_utils.h
@@ -348,11 +348,13 @@ static paddle::Tensor getTensorWithBasicIndexing(
     std::vector<int64_t>* decrease_axis,
     std::vector<int64_t>* none_axes,
     std::vector<int64_t>* infer_flags,
-    bool* use_strided_slice) {
+    bool* use_strided_slice,
+    bool* out_is_view) {
   paddle::Tensor out;
   if (slice_axes->empty()) {
     out = tensor;
   } else {
+    *out_is_view = true;
     if (!(*use_strided_slice)) {
       eager_gil_scoped_release guard;
       out = slice_ad_func(tensor,
@@ -373,6 +375,7 @@ static paddle::Tensor getTensorWithBasicIndexing(
     }
   }
   if (!none_axes->empty()) {
+    *out_is_view = true;
     eager_gil_scoped_release guard;
     // Deal with cases that decrease_axes is not empty
     // For example:
@@ -401,7 +404,8 @@ static paddle::Tensor dealWithAdvancedIndex(
     std::vector<int>* trans_back_dim,
     int* pos_of_new_dim,
     int* rank_of_new_dim,
-    std::vector<int>* trans_dim) {
+    std::vector<int>* trans_dim,
+    bool* out_is_view) {
   int p = 0;
   for (size_t i = 0; i < advanced_index_dim->size(); ++i) {
     auto index_dim = (*advanced_index_dim)[i];
@@ -444,6 +448,7 @@ static paddle::Tensor dealWithAdvancedIndex(
   if (original_dim_order == *trans_dim) {
     transed_tensor = tensor;
   } else {
+    *out_is_view = true;
     transed_tensor = transpose_ad_func(tensor, *trans_dim);
   }
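
dealWithAdvancedIndex implements the classic transpose-to-front strategy: move every advanced-indexed axis to the front, gather or scatter on the leading axes, and (for setitem) undo the move with trans_back_dim. A NumPy sketch of the same idea (illustrative only, not Paddle code):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
idx = [2, 0]

ref = x[:, idx]                      # NumPy advanced indexing on axis 1

trans_dim = [1, 0, 2]                # advanced axis moved to the front
transed = x.transpose(trans_dim)     # shape (3, 2, 4)
gathered = transed[idx]              # index the leading axis, shape (2, 2, 4)

# With a single advanced axis, the result differs from NumPy's only by the
# transpose that moved that axis to the front:
np.testing.assert_array_equal(gathered, ref.transpose(1, 0, 2))

# For setitem, trans_back_dim = np.argsort(trans_dim) restores the original
# layout after the scatter, as the C++ above does with transpose_ad_func.

Note that the else branch above only sets *out_is_view when a transpose was actually needed; if the advanced indices already sit at the front, the tensor is passed through untouched.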
58 changes: 33 additions & 25 deletions python/paddle/base/variable_index.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
 
 import numpy as np
 
@@ -170,7 +169,9 @@ def _setitem_for_tensor_array(var, item, value):
     )
 
 
-def deal_advanced_index(ori_tensor, indices, is_for_setitem, values):
+def deal_advanced_index(
+    ori_tensor, indices, is_for_setitem, values, out_is_view=True
+):
     """
     Transpose origin Tensor and advanced indices to the front.
@@ -206,18 +207,24 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem, values):
     for i in range(ori_tensor.ndim):
         if indices[i] is None:
             transed_dim.append(i)
-    transed_tensor = ori_tensor.transpose(transed_dim)
 
     trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else []
 
     transed_value_tensor = None
-    if is_for_setitem:
-        if values.ndim > 1 and pos_of_new_dim != 0:
-            # If the value tensor is not a scalar / 1-D Tensor, and the src tensor was
-            # transposed at 1st dim, the value tensor should be transposed too.
-            transed_value_tensor = values.transpose(transed_dim)
-        else:
+
+    if transed_dim == list(range(ori_tensor.ndim)):
+        transed_tensor = ori_tensor
+        if is_for_setitem:
             transed_value_tensor = values
+    else:
+        out_is_view = True
+        transed_tensor = ori_tensor.transpose(transed_dim)
+        if is_for_setitem:
+            if values.ndim > 1 and pos_of_new_dim != 0:
+                # If the value tensor is not a scalar / 1-D Tensor, and the src tensor was
+                # transposed at 1st dim, the value tensor should be transposed too.
+                transed_value_tensor = values.transpose(transed_dim)
+            else:
+                transed_value_tensor = values
 
     return (
         transed_tensor,
@@ -226,6 +233,7 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem, values):
         pos_of_new_dim,
         rank_of_new_dim,
         transed_value_tensor,
+        out_is_view,
     )
 
 
@@ -599,7 +607,7 @@ def _setitem_static(x, indices, values):
     ):
         values = paddle.assign(values).astype(x.dtype)
 
-    sub_tensor = get_tensor_with_basic_indexing(
+    sub_tensor, is_view = get_tensor_with_basic_indexing(
         x,
         axes,
         starts,
@@ -616,16 +624,21 @@ def _setitem_static(x, indices, values):
             _,
             _,
             values,
-        ) = deal_advanced_index(sub_tensor, advanced_index, True, values)
+            is_view,
+        ) = deal_advanced_index(
+            sub_tensor, advanced_index, True, values, is_view
+        )
 
         if values.dtype != transed_sub_tensor.dtype:
             values = values.astype(transed_sub_tensor.dtype)
 
-        if in_dynamic_or_pir_mode():
+        if paddle.in_dynamic_mode():
+            # NOTE(zoooo0820): directly return result instead of another set_value, after backward bug fixed.
             transed_sub_tensor = transed_sub_tensor.index_put_(
                 adjusted_advanced_index, values
            )
+            if not is_view:
+                return transed_sub_tensor
        else:
            transed_sub_tensor = transed_sub_tensor.index_put(
                adjusted_advanced_index, values
            )
@@ -694,12 +707,14 @@ def get_tensor_with_basic_indexing(
 ):
     from .dygraph.base import in_to_static_mode
 
+    out_is_view = False
     if in_to_static_mode() and hasattr(x, "is_view_var"):
         x.is_view_var = True
 
     if len(axes) == 0:
         out = x
     else:
+        out_is_view = True
         op_type = "strided_slice" if use_strided_slice else "slice"
         inputs = {'Input': [x]}
         attrs = {
@@ -748,7 +763,7 @@ def get_tensor_with_basic_indexing(
             if paddle.utils._contain_var(end):
                 end = paddle.utils.get_int_tensor_list(end)
             if x.is_dense_tensor_array_type():
-                return paddle._pir_ops.slice_array_dense(x, st)
+                return paddle._pir_ops.slice_array_dense(x, st), False
             out = paddle._C_ops.slice(
                 x,
                 axes,
@@ -775,17 +790,9 @@ def get_tensor_with_basic_indexing(
                 attrs=attrs,
             )
             out = slice_out_var
-    # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
-    # with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
-    # otherwise the output shape will be not correct.
-    set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']
-    if set_to_1d and len(decrease_axes) == len(x.shape):
-        warnings.warn(
-            "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')."
-        )
-        none_axes = none_axes[1:]
 
     if len(none_axes) > 0:
+        out_is_view = True
         # Deal with cases that decrease_axes is not empty
         # For example:
         # # x.shape: (2,3,4)
@@ -799,7 +806,7 @@ def get_tensor_with_basic_indexing(
 
     if in_to_static_mode() and hasattr(out, "is_view_var"):
         out.is_view_var = True
-    return out
+    return out, out_is_view
 
 
 def _getitem_static(x, indices):
@@ -822,7 +829,7 @@ def _getitem_static(x, indices):
     ) = parse_index(x, indices)
 
     # step2: Dealing with basic indexing
-    out = get_tensor_with_basic_indexing(
+    out, _ = get_tensor_with_basic_indexing(
         x,
         axes,
         starts,
@@ -842,6 +849,7 @@ def _getitem_static(x, indices):
         pos_of_new_dim,
         rank_of_new_dim,
         _,
+        _,
     ) = deal_advanced_index(out, advanced_index, False, None)
 
     # TODO(zooooo0820): Replacing gather_nd to another advanded OP for handling of mixed indexes more efficiently
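
Note the asymmetry introduced here: only paddle.in_dynamic_mode() takes the early-return path with the in-place index_put_; static-graph and PIR programs keep the non-inplace index_put plus set_value write-back, which is what the new static test below exercises. A standalone version of that check (a sketch mirroring the test; assumes _setitem_static is importable from paddle.base.variable_index, as the test file does):

import numpy as np
import paddle
from paddle.base.variable_index import _setitem_static

paddle.enable_static()
exe = paddle.static.Executor()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.zeros((3, 4), dtype='float32')
    # Lowered through the static branch: index_put + set_value; the result is
    # a new program variable, so fetch y rather than x.
    y = _setitem_static(x, ([0, 2], [1, 3]), 7.0)
    (res,) = exe.run(fetch_list=[y])

expected = np.zeros((3, 4), dtype='float32')
expected[[0, 2], [1, 3]] = 7.0
np.testing.assert_allclose(res, expected)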
29 changes: 29 additions & 0 deletions test/indexing/test_setitem.py
@@ -29,6 +29,21 @@ def setUp(self):
         self.ndtype = np.float64
         self.dtype = 'float64'
 
+    def test_advanced_index(self):
+        np_data = np.zeros((3, 4, 5, 6), dtype='float32').astype(self.ndtype)
+        if self.dtype == 'bfloat16':
+            np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
+        if self.dtype == 'complex64' or self.dtype == 'complex128':
+            np_data = np_data + 1j * np_data
+
+        x = paddle.to_tensor(np_data, dtype=self.dtype)
+        np_data[[0, 1], [1, 2], [1]] = 10.0
+        x[[0, 1], [1, 2], [1]] = 10.0
+
+        if self.dtype == 'bfloat16':
+            x = paddle.cast(x, dtype='float32')
+        np.testing.assert_allclose(x.numpy(), np_data)
+
     def test_combined_index_1(self):
         np_data = np.zeros((3, 4, 5, 6), dtype='float32').astype(self.ndtype)
         if self.dtype == 'bfloat16':
@@ -426,6 +441,20 @@ def setUp(self):
         paddle.enable_static()
         self.exe = paddle.static.Executor()
 
+    @test_with_pir_api
+    def test_advanced_index(self):
+        # multi-int tensor
+        np_data = np.zeros((3, 4, 5, 6), dtype='float32')
+        np_data[[0, 1], [1, 2], [1]] = 10.0
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+            x = paddle.zeros((3, 4, 5, 6), dtype='float32')
+            y = _setitem_static(x, ([0, 1], [1, 2], [1]), 10.0)
+            res = self.exe.run(fetch_list=[y])
+
+        np.testing.assert_allclose(res[0], np_data)
+
     @test_with_pir_api
     def test_combined_index_1(self):
         # int tensor + slice (without decreasing axes)
