Skip to content

Commit

Permalink
fix cinn graph may hasn't input problem (#40814)
Browse files Browse the repository at this point in the history
  • Loading branch information
thisjiang authored Mar 23, 2022
1 parent db41e39 commit 17b8335
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
51 changes: 50 additions & 1 deletion paddle/fluid/operators/cinn/cinn_instruction_run_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
// The cinn-graph may hasn't input for CINN now support fill_constant,
// and its all inputs may generated by fill_constant instead of by fetch.
// OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnInstructionRun");
const CinnCompiledObject& compiled_object =
Expand All @@ -43,6 +45,53 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
});
ctx->SetOutputsDim(kOutputs, output_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// Why we need override GetExpectedKernelType?
// A cinn-graph may has no inpute var, if we use the base function,
// it will check wheter input tensors is initialized. Here we rewrite
// the function so that we can infer kernel type by output date type.
if (ctx.InputSize(kX)) {
// if the instruction has input, infer kernel type by input date type:
return OperatorWithKernel::GetExpectedKernelType(ctx);
}

// Else infer kernel type by output date type:
// The `OutputVar` will check wheter the kOutputs iff has one output var
const framework::Variable* var = ctx.OutputVar(kOutputs);
PADDLE_ENFORCE_NE(
var, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Variable should not empty."));

const framework::Tensor* tensor = nullptr;
if (var->IsType<framework::Tensor>()) {
tensor = &var->Get<framework::Tensor>();
} else if (var->IsType<framework::LoDTensor>()) {
tensor = &var->Get<framework::LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
tensor = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<framework::LoDTensorArray>()) {
auto t_arr = &var->Get<framework::LoDTensorArray>();
PADDLE_ENFORCE_EQ(t_arr->size(), 1UL,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op should just has One "
"Output when Input empty."));
tensor = &(t_arr->front());
}

PADDLE_ENFORCE_NE(
tensor, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Tensor should not empty."));

VLOG(4) << "The tensor [" << ctx.OutputName(kOutputs) << "]'s dtype is "
<< paddle::framework::DataType2String(tensor->dtype());
auto output_type = paddle::framework::TransToProtoVarType(tensor->dtype());
return framework::OpKernelType(output_type, ctx.device_context());
}
};

class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/operators/cinn/cinn_launch_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,12 @@ class CinnLaunchOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
"Input", string::format_string("%s|%s", kX, kNoNeedBufferX),
"CinnLaunchOp");
// The cinn-graph may hasn't input for CINN now support fill_constant,
// and its all inputs may generated by fill_constant instead of by fetch.
// OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
// "Input", string::format_string("%s|%s", kX,
// kNoNeedBufferX),
// "CinnLaunchOp");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnLaunchOp");
}
Expand Down

0 comments on commit 17b8335

Please sign in to comment.