supplement the function of slice. (#34172)
* supplement the function of slice

* edit unittest

* strided_slice_op support.

* polish error message.

* polish error message.

* polish code.

* polish unittest.

* polish code.

* polish code

* polish error message.
hbwx24 authored Aug 4, 2021
1 parent c79fa1c commit 1f0f5d3
Showing 5 changed files with 696 additions and 65 deletions.
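
This commit teaches strided_slice (and its gradient op) to accept a LoDTensorArray as `Input`: shape inference returns early at runtime for tensor-array inputs (the output shape is only known once the kernel runs), the expected-kernel-type check walks every tensor in the array, and new VarTypeInference classes propagate the array type to the output. As a rough illustration of the array-axis slicing semantics this enables, here is a minimal, self-contained C++ sketch; `Tensor`, `TensorArray`, and `StridedSliceArray` are stand-in names rather than Paddle types, and the real operator additionally handles per-dimension attributes such as axes and tensor-valued starts/ends that are omitted here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Tensor = std::vector<float>;        // stand-in for a LoDTensor
using TensorArray = std::vector<Tensor>;  // stand-in for a LoDTensorArray

// Slice the array axis with (start, end, stride), mirroring the usual
// strided-slice semantics on the outermost "array" dimension.
TensorArray StridedSliceArray(const TensorArray &in, int64_t start,
                              int64_t end, int64_t stride) {
  assert(stride != 0);
  const int64_t size = static_cast<int64_t>(in.size());
  // Negative indices count from the end, as in the Python slice convention.
  if (start < 0) start += size;
  if (end < 0) end += size;
  TensorArray out;
  if (stride > 0) {
    for (int64_t i = std::max<int64_t>(start, 0); i < std::min(end, size);
         i += stride) {
      out.push_back(in[i]);
    }
  } else {
    for (int64_t i = std::min(start, size - 1);
         i > std::max<int64_t>(end, -1); i += stride) {
      out.push_back(in[i]);
    }
  }
  return out;
}

int main() {
  TensorArray arr = {{0.f}, {1.f}, {2.f}, {3.f}, {4.f}};
  // Pick every second element starting at index 1: elements 1 and 3.
  TensorArray sliced = StridedSliceArray(arr, 1, 5, 2);
  for (const auto &t : sliced) std::cout << t[0] << " ";  // prints: 1 3
  std::cout << std::endl;
  return 0;
}
```
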
61 changes: 58 additions & 3 deletions paddle/fluid/operators/strided_slice_op.cc
@@ -31,7 +31,13 @@ class StridedSliceOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice");

auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined at runtime.
return;
}
}
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
@@ -154,6 +160,27 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *in_var = ctx.InputVar("Input");
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
if (is_in_var_array) {
auto &tensor_array = in_var->Get<framework::LoDTensorArray>();
for (auto &tensor : tensor_array) {
if (!platform::is_cuda_pinned_place(tensor.place())) {
PADDLE_ENFORCE_EQ(
platform::is_same_place(tensor.place(),
ctx.device_context().GetPlace()),
true,
platform::errors::InvalidArgument(
"Place of context is %s. Place of input tensor is %s. They "
"are should be same, but reveived different place.",
string::to_string(ctx.device_context().GetPlace()),
string::to_string(tensor.place())));
}
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
// NOTE: cuda pinned tensor need to copy its data to target place
auto in_tensor = ctx.Input<Tensor>("Input");
if (platform::is_cuda_pinned_place(in_tensor->place())) {
@@ -179,6 +206,14 @@ class StridedSliceOp : public framework::OperatorWithKernel {
}
};

class StridedSliceOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SetOutputType("Out", ctx->GetInputType("Input"));
ctx->SetOutputDataType("Out", ctx->GetInputDataType("Input"));
}
};

class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -259,6 +294,13 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "StridedSliceGrad");

auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined at runtime
return;
}
}
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
@@ -308,6 +350,16 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> {
bind->SetType("strided_slice_grad");
}
};
class StridedSliceGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SetOutputType(framework::GradVarName("Input"),
ctx->GetInputType(framework::GradVarName("Out")));
ctx->SetOutputDataType(
framework::GradVarName("Input"),
ctx->GetInputDataType(framework::GradVarName("Out")));
}
};

DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
"Input");
@@ -318,9 +370,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
namespace ops = paddle::operators;
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>);
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
ops::StridedSliceOpVarTypeInference);

REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
- ops::StridedSliceOpGradNoNeedBufferVarsInferer);
+ ops::StridedSliceOpGradNoNeedBufferVarsInferer,
+ ops::StridedSliceGradOpVarTypeInference);

REGISTER_OP_CPU_KERNEL(
strided_slice,
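
The two *VarTypeInference classes registered above exist so that the output (and the input gradient) inherit the variable type of their source: when `Input` is a LoDTensorArray, `Out` must be declared as a LoDTensorArray as well, otherwise the program description would still treat it as a dense LoDTensor. Below is a toy sketch of that propagation pattern; `VarType`, `VarDesc`, and `InferCtx` are simplified stand-ins for illustration, not the real `InferVarTypeContext` API.

```cpp
#include <iostream>
#include <map>
#include <string>

enum class VarType { LOD_TENSOR, LOD_TENSOR_ARRAY };
enum class DataType { FP32, INT64 };

struct VarDesc {
  VarType type;
  DataType dtype;
};

struct InferCtx {
  std::map<std::string, VarDesc> vars;  // variable name -> description
};

// Mirrors the idea behind StridedSliceOpVarTypeInference::operator():
// "Out" simply takes on the variable type and data type of "Input".
void InferStridedSliceVarType(InferCtx *ctx) {
  ctx->vars["Out"].type = ctx->vars["Input"].type;
  ctx->vars["Out"].dtype = ctx->vars["Input"].dtype;
}

int main() {
  InferCtx ctx;
  ctx.vars["Input"] = {VarType::LOD_TENSOR_ARRAY, DataType::FP32};
  ctx.vars["Out"] = {VarType::LOD_TENSOR, DataType::FP32};
  InferStridedSliceVarType(&ctx);
  std::cout << std::boolalpha
            << (ctx.vars["Out"].type == VarType::LOD_TENSOR_ARRAY) << "\n";
  // prints: true -- Out is now declared as a tensor array, matching Input.
  return 0;
}
```
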