Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

supplement the function of slice. #34172

Merged
merged 12 commits into from
Aug 4, 2021
61 changes: 58 additions & 3 deletions paddle/fluid/operators/strided_slice_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ class StridedSliceOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice");

auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined by Runtime.
return;
}
}
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
Expand Down Expand Up @@ -154,6 +160,27 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *in_var = ctx.InputVar("Input");
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

结合下面的code,会不会有这种情况,lodtensorarray里面tensor的place是cuda_pinned

Copy link
Contributor Author

@hbwx24 hbwx24 Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.

if (is_in_var_array) {
auto &tensor_array = in_var->Get<framework::LoDTensorArray>();
for (auto &tensor : tensor_array) {
if (!platform::is_cuda_pinned_place(tensor.place())) {
PADDLE_ENFORCE_EQ(
platform::is_same_place(tensor.place(),
ctx.device_context().GetPlace()),
true,
platform::errors::InvalidArgument(
"Place of context is %s. Place of input tensor is %s. They "
"are should be same, but reveived different place.",
string::to_string(ctx.device_context().GetPlace()),
string::to_string(tensor.place())));
}
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
// NOTE: cuda pinned tensor need to copy its data to target place
auto in_tensor = ctx.Input<Tensor>("Input");
if (platform::is_cuda_pinned_place(in_tensor->place())) {
Expand All @@ -179,6 +206,14 @@ class StridedSliceOp : public framework::OperatorWithKernel {
}
};

class StridedSliceOpVarTypeInference : public framework::VarTypeInference {
 public:
  // Mirrors the "Input" variable's type and data type onto "Out", so that a
  // LoDTensorArray input yields a LoDTensorArray output (and likewise for a
  // plain LoDTensor). The two propagations below are independent of each
  // other, so their order does not matter.
  void operator()(framework::InferVarTypeContext *infer_ctx) const override {
    infer_ctx->SetOutputDataType("Out", infer_ctx->GetInputDataType("Input"));
    infer_ctx->SetOutputType("Out", infer_ctx->GetInputType("Input"));
  }
};

class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
Expand Down Expand Up @@ -259,6 +294,13 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "StridedSliceGrad");

auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined by Runtime
return;
}
}
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
Expand Down Expand Up @@ -308,6 +350,16 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> {
bind->SetType("strided_slice_grad");
}
};
class StridedSliceGradOpVarTypeInference : public framework::VarTypeInference {
 public:
  // Forwards both the variable type and the element data type from the
  // incoming gradient (Out@GRAD) to the produced gradient (Input@GRAD), so
  // array-typed gradients stay array-typed through the backward pass.
  void operator()(framework::InferVarTypeContext *infer_ctx) const override {
    // Name the gradient variables once instead of recomputing GradVarName
    // for each propagation call.
    const auto in_grad = framework::GradVarName("Input");
    const auto out_grad = framework::GradVarName("Out");
    infer_ctx->SetOutputDataType(in_grad, infer_ctx->GetInputDataType(out_grad));
    infer_ctx->SetOutputType(in_grad, infer_ctx->GetInputType(out_grad));
  }
};

// NOTE(review): declares that the grad op does not need the data buffer of
// "Input" — presumably only its shape/metadata is consumed at backward time,
// letting the framework free the forward tensor's memory. TODO confirm
// against DECLARE_NO_NEED_BUFFER_VARS_INFERER semantics.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
"Input");
Expand All @@ -318,9 +370,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
namespace ops = paddle::operators;
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>);
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
ops::StridedSliceOpVarTypeInference);

REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
ops::StridedSliceOpGradNoNeedBufferVarsInferer);
ops::StridedSliceOpGradNoNeedBufferVarsInferer,
ops::StridedSliceGradOpVarTypeInference);

REGISTER_OP_CPU_KERNEL(
strided_slice,
Expand Down
Loading