Skip to content

Commit

Permalink
[MLU] set_value performance optimizing
Browse files Browse the repository at this point in the history
  • Loading branch information
fuyou765 committed Jul 19, 2022
1 parent 068f48d commit 63a148f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 42 deletions.
97 changes: 55 additions & 42 deletions paddle/fluid/operators/set_value_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <numeric>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/set_value_op.h"
Expand Down Expand Up @@ -62,7 +63,6 @@ class SetValueMLUKernel : public framework::OpKernel<T> {
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none;

size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() &&
Expand All @@ -84,51 +84,22 @@ class SetValueMLUKernel : public framework::OpKernel<T> {

slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
}

auto starts_indices = std::vector<int64_t>(in_dims.size(), 0);
auto ends_indices = std::vector<int64_t>(in_dims.size(), 0);
auto strides_indices = std::vector<int64_t>(in_dims.size(), 0);
int in_size = in_dims.size();
int starts_indices[in_size] = {0};
int ends_indices[in_size] = {0};
int strides_indices[in_size] = {0};

for (int i = 0; i < in_dims.size(); ++i) {
starts_indices[i] = 0;
ends_indices[i] = slice_dims[i];
ends_indices[i] = static_cast<int>(slice_dims[i]);
strides_indices[i] = 1;
}
for (size_t i = 0; i < axes.size(); i++) {
int axis_index = axes[i];
starts_indices[axis_index] = starts[i];
ends_indices[axis_index] = ends[i];
strides_indices[axis_index] = steps[i];
}

int64_t stride_step = phi::product(in_dims);
std::vector<int64_t> index_indices(1, 0);
for (size_t i = 0; i < strides_indices.size(); ++i) {
auto index_size = index_indices.size();
stride_step /= in_dims[i];
for (size_t j = 0; j < index_size; ++j) {
auto start_index = *index_indices.begin();
if (strides_indices[i] > 0) {
for (int64_t k = starts_indices[i]; k < ends_indices[i];
k += strides_indices[i]) {
index_indices.push_back(start_index + k * stride_step);
}
} else {
for (int64_t k = starts_indices[i]; k > ends_indices[i];
k += strides_indices[i]) {
index_indices.push_back(start_index + k * stride_step);
}
}
index_indices.erase(index_indices.begin());
}
starts_indices[axis_index] = static_cast<int>(starts[i]);
ends_indices[axis_index] = static_cast<int>(ends[i]);
strides_indices[axis_index] = static_cast<int>(steps[i]);
}

PADDLE_ENFORCE_EQ(
static_cast<int64_t>(index_indices.size()),
phi::product(slice_dims_for_assign),
platform::errors::InvalidArgument(
"OP(set_value) error index indices and value update not match "));

Tensor value_t(in->type());
if (value_tensor != nullptr) {
value_t.ShareDataWith(*value_tensor);
Expand Down Expand Up @@ -160,29 +131,71 @@ class SetValueMLUKernel : public framework::OpKernel<T> {

int64_t input_numel = phi::product(in_dims);
int64_t value_numel = phi::product(value_temp.dims());
Tensor in_temp, out_temp, val_temp;
Tensor in_temp, out_temp, val_temp, index_out;
int64_t stride_step = phi::product(in_dims);
std::vector<int64_t> index_indices(stride_step);
std::iota(index_indices.begin(), index_indices.end(), 0);
framework::Tensor index_temp;
in_temp.ShareDataWith(*in);
val_temp.ShareDataWith(value_temp);
paddle::framework::TensorFromVector(
index_indices, ctx.device_context(), &index_temp);
index_temp.Resize(in_dims);
auto index_dims = in_dims;
for (int i = 0; i < in_dims.size(); ++i) {
if (starts_indices[i] < 0 || ends_indices[i] < 0) {
starts_indices[i] -= in_dims[i];
ends_indices[i] -= in_dims[i];
}
if (strides_indices[i] > 0)
index_dims[i] =
static_cast<int>((ends_indices[i] - starts_indices[i] - 1) /
strides_indices[i]) +
1;
else
index_dims[i] =
static_cast<int>((ends_indices[i] - starts_indices[i] + 1) /
strides_indices[i]) +
1;
}
auto new_in_dims = phi::make_ddim({input_numel});
auto new_val_dims = phi::make_ddim({value_numel});
in_temp.Resize(new_in_dims);
val_temp.Resize(new_val_dims);
index_out.Resize(index_dims);
index_out.mutable_data<int64_t>(ctx.GetPlace());
cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
MLUCnnlTensorDesc x_desc(in_temp);
MLUCnnlTensorDesc indices_desc(index_temp);
MLUCnnlTensorDesc indices_out_desc(index_out);
MLUCnnlTensorDesc updates_desc(val_temp);
MLUCnnlTensorDesc out_desc(*out);

MLUCnnl::StridedSlice(ctx,
starts_indices,
ends_indices,
strides_indices,
indices_desc.get(),
GetBasePtr(&index_temp),
indices_out_desc.get(),
GetBasePtr(&index_out));
PADDLE_ENFORCE_EQ(
static_cast<int64_t>(phi::product(index_out.dims())),
phi::product(slice_dims_for_assign),
platform::errors::InvalidArgument(
"OP(set_value) error index indices and value update not match "));
Tensor index_final;
index_final.ShareDataWith(index_out);
int64_t indices_numel = phi::product(index_dims);
auto new_index_dims = phi::make_ddim({indices_numel});
index_final.Resize(new_index_dims);
MLUCnnlTensorDesc indices_final_desc(index_final);
MLUCnnl::ScatterRefFunctor(ctx,
x_desc.get(),
GetBasePtr(&in_temp),
updates_desc.get(),
GetBasePtr(&val_temp),
indices_desc.get(),
GetBasePtr(&index_temp),
indices_final_desc.get(),
GetBasePtr(&index_final),
mode);
in_temp.Resize(in_dims);
paddle::framework::TensorCopy(in_temp, ctx.GetPlace(), out);
Expand Down
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/mlu/test_set_value_op_mlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ def _get_answer(self):
self.data[0:, 1:2, :] = self.value


class TestSetValueItemSlice5(TestSetValueApi):

def set_shape(self):
self.shape = [100, 426, 640]

def _call_setitem(self, x):
x[0:-1] = self.value

def _get_answer(self):
self.data[0:-1] = self.value


#TODO: Fix this after MLU support while_loop
#class TestSetValueItemSliceInWhile(TestSetValueApi):
# def _call_setitem(self, x):
Expand Down Expand Up @@ -517,6 +529,7 @@ def set_dtype(self):
create_test_value_int32(TestSetValueItemSlice2)
create_test_value_int32(TestSetValueItemSlice3)
create_test_value_int32(TestSetValueItemSlice4)
create_test_value_int32(TestSetValueItemSlice5)


def create_test_value_tensor_fp32(parent):
Expand All @@ -543,6 +556,7 @@ def _get_answer(self):
create_test_value_tensor_fp32(TestSetValueItemSlice2)
create_test_value_tensor_fp32(TestSetValueItemSlice3)
create_test_value_tensor_fp32(TestSetValueItemSlice4)
create_test_value_tensor_fp32(TestSetValueItemSlice5)


# 3. Test different shape of value
Expand Down

0 comments on commit 63a148f

Please sign in to comment.