Skip to content

Commit

Permalink
add ge_p abs_p primitive oparators
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxly committed Sep 9, 2022
1 parent 44f6771 commit 8db23c1
Show file tree
Hide file tree
Showing 18 changed files with 716 additions and 43 deletions.
4 changes: 3 additions & 1 deletion paddle/fluid/operators/prim_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ set(PRIM_OP_SRCS
select_p_op.cc
eq_p_op.cc
gt_p_op.cc
ge_p_op.cc
ne_p_op.cc
pow_p_op.cc
max_p_op.cc
erf_p_op.cc)
erf_p_op.cc
abs_p_op.cc)

cc_test(
prim_op_test
Expand Down
71 changes: 71 additions & 0 deletions paddle/fluid/operators/prim_ops/abs_p_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {
class AbsPrimOp : public framework::OperatorBase {
public:
AbsPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator abs_p should not be excuted directly"));
}
};

class AbsPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of abs_p op.");
AddOutput("Y", "(Tensor), The output tensor of abs_p op.");
AddComment(R"DOC(Autograd primitive abs_p operator.)DOC");
}
};

class AbsPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};

class AbsPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(abs_p,
paddle::operators::AbsPrimOp,
paddle::operators::AbsPrimOpMaker,
paddle::operators::AbsPrimOpShapeInference,
paddle::operators::AbsPrimOpVarTypeInference);
119 changes: 119 additions & 0 deletions paddle/fluid/operators/prim_ops/ge_p_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {
class GePrimOp : public framework::OperatorBase {
public:
GePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator ge_p should not be excuted directly"));
}
};

class GePrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of ge_p op.");
AddInput("Y", "(Tensor), The input tensor of ge_p op.");
AddOutput("Z", "(Tensor), The output tensor of ge_p op.");
AddComment(R"DOC(
Autograd primitive ge_p operator.
)DOC");
}
};

class GePrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];

framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}

PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};

class GePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));

SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(ge_p,
paddle::operators::GePrimOp,
paddle::operators::GePrimOpMaker,
paddle::operators::GePrimOpShapeInference,
paddle::operators::GePrimOpVarTypeInference);
119 changes: 119 additions & 0 deletions paddle/fluid/operators/prim_ops/gt_p_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {
class GtPrimOp : public framework::OperatorBase {
public:
GtPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator gt_p should not be excuted directly"));
}
};

class GtPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of gt_p op.");
AddInput("Y", "(Tensor), The input tensor of gt_p op.");
AddOutput("Z", "(Tensor), The output tensor of gt_p op.");
AddComment(R"DOC(
Autograd primitive gt_p operator.
)DOC");
}
};

class GtPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];

framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}

PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};

class GtPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));

SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(gt_p,
paddle::operators::GtPrimOp,
paddle::operators::GtPrimOpMaker,
paddle::operators::GtPrimOpShapeInference,
paddle::operators::GtPrimOpVarTypeInference);
Loading

0 comments on commit 8db23c1

Please sign in to comment.