Skip to content

Commit

Permalink
Add sin, cos, and exp primitive operators
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxly committed Jul 26, 2022
1 parent f9cd526 commit a375701
Show file tree
Hide file tree
Showing 10 changed files with 765 additions and 205 deletions.
71 changes: 71 additions & 0 deletions paddle/fluid/operators/prim_ops/cos_p_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {
// cos_p is a symbolic autograd primitive: it only carries graph information
// and must be lowered to concrete operators before execution, so running it
// directly is always an error.
class CosPrimOp : public framework::OperatorBase {
 public:
  CosPrimOp(const std::string &type,
            const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    // Fixed typo in the user-facing message: "excuted" -> "executed".
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator cos_p should not be executed directly"));
  }
};

// Declares the op proto for cos_p: one input X, one output Y, no attributes.
class CosPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of cos_p op.");
    AddOutput("Y", "(Tensor), The output tensor of cos_p op.");
    AddComment(R"DOC(Autograd primitive cos_p operator.)DOC");
  }
};

// Shape inference for cos_p: the op is elementwise, so the output Y simply
// inherits the static shape of the input X.
class CosPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    auto x_ptr = ctx->GetInputVarPtrs("X")[0];
    auto y_ptr = ctx->GetOutputVarPtrs("Y")[0];
    auto *y_desc = PADDLE_GET(framework::VarDesc *, y_ptr);
    auto *x_desc = PADDLE_GET(framework::VarDesc *, x_ptr);
    y_desc->SetShape(x_desc->GetShape());
  }
};

// Var type inference for cos_p: output Y carries the same variable type and
// data type as input X.
class CosPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto in_name = Input(ctx, "X")[0];
    auto out_name = Output(ctx, "Y")[0];
    SetType(ctx, out_name, GetType(ctx, in_name));
    SetDataType(ctx, out_name, GetDataType(ctx, in_name));
  }
};

} // namespace operators
} // namespace paddle

// Register cos_p with its maker, shape inference and var type inference.
// No kernels are registered: prim ops are never executed directly.
REGISTER_OPERATOR(cos_p,
                  paddle::operators::CosPrimOp,
                  paddle::operators::CosPrimOpMaker,
                  paddle::operators::CosPrimOpShapeInference,
                  paddle::operators::CosPrimOpVarTypeInference);
71 changes: 71 additions & 0 deletions paddle/fluid/operators/prim_ops/exp_p_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {
// exp_p is a symbolic autograd primitive: it only carries graph information
// and must be lowered to concrete operators before execution, so running it
// directly is always an error.
class ExpPrimOp : public framework::OperatorBase {
 public:
  ExpPrimOp(const std::string &type,
            const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    // Fixed typo in the user-facing message: "excuted" -> "executed".
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator exp_p should not be executed directly"));
  }
};

// Declares the op proto for exp_p: one input X, one output Y, no attributes.
class ExpPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of exp_p op.");
    AddOutput("Y", "(Tensor), The output tensor of exp_p op.");
    AddComment(R"DOC(Autograd primitive exp_p operator.)DOC");
  }
};

// Shape inference for exp_p: the op is elementwise, so the output Y simply
// inherits the static shape of the input X.
class ExpPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    auto x_ptr = ctx->GetInputVarPtrs("X")[0];
    auto y_ptr = ctx->GetOutputVarPtrs("Y")[0];
    auto *y_desc = PADDLE_GET(framework::VarDesc *, y_ptr);
    auto *x_desc = PADDLE_GET(framework::VarDesc *, x_ptr);
    y_desc->SetShape(x_desc->GetShape());
  }
};

// Var type inference for exp_p: output Y carries the same variable type and
// data type as input X.
class ExpPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto in_name = Input(ctx, "X")[0];
    auto out_name = Output(ctx, "Y")[0];
    SetType(ctx, out_name, GetType(ctx, in_name));
    SetDataType(ctx, out_name, GetDataType(ctx, in_name));
  }
};

} // namespace operators
} // namespace paddle

// Register exp_p with its maker, shape inference and var type inference.
// No kernels are registered: prim ops are never executed directly.
REGISTER_OPERATOR(exp_p,
                  paddle::operators::ExpPrimOp,
                  paddle::operators::ExpPrimOpMaker,
                  paddle::operators::ExpPrimOpShapeInference,
                  paddle::operators::ExpPrimOpVarTypeInference);
71 changes: 71 additions & 0 deletions paddle/fluid/operators/prim_ops/sin_p_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {
// sin_p is a symbolic autograd primitive: it only carries graph information
// and must be lowered to concrete operators before execution, so running it
// directly is always an error.
class SinPrimOp : public framework::OperatorBase {
 public:
  SinPrimOp(const std::string &type,
            const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    // Fixed typo in the user-facing message: "excuted" -> "executed".
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator sin_p should not be executed directly"));
  }
};

// Declares the op proto for sin_p: one input X, one output Y, no attributes.
class SinPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of sin_p op.");
    AddOutput("Y", "(Tensor), The output tensor of sin_p op.");
    AddComment(R"DOC(Autograd primitive sin_p operator.)DOC");
  }
};

// Shape inference for sin_p: the op is elementwise, so the output Y simply
// inherits the static shape of the input X.
class SinPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    auto x_ptr = ctx->GetInputVarPtrs("X")[0];
    auto y_ptr = ctx->GetOutputVarPtrs("Y")[0];
    auto *y_desc = PADDLE_GET(framework::VarDesc *, y_ptr);
    auto *x_desc = PADDLE_GET(framework::VarDesc *, x_ptr);
    y_desc->SetShape(x_desc->GetShape());
  }
};

// Var type inference for sin_p: output Y carries the same variable type and
// data type as input X.
class SinPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto in_name = Input(ctx, "X")[0];
    auto out_name = Output(ctx, "Y")[0];
    SetType(ctx, out_name, GetType(ctx, in_name));
    SetDataType(ctx, out_name, GetDataType(ctx, in_name));
  }
};

} // namespace operators
} // namespace paddle

// Register sin_p with its maker, shape inference and var type inference.
// No kernels are registered: prim ops are never executed directly.
REGISTER_OPERATOR(sin_p,
                  paddle::operators::SinPrimOp,
                  paddle::operators::SinPrimOpMaker,
                  paddle::operators::SinPrimOpShapeInference,
                  paddle::operators::SinPrimOpVarTypeInference);
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,97 @@ def init_data(self):
]


class TestSinPJVPAndTranspose(TestAddPJVPAndTranspose):
    """JVP/transpose lowering test for the sin_p primitive."""

    def init_data(self):
        # Primitive under test: Y = sin(X).
        self.op_type = 'sin_p'
        x = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        self.prim_input = {'X': x}
        out = self.layer_help.create_variable_for_type_inference(dtype=x.dtype)
        self.prim_output = {'Y': out}
        self.prim_attrs = {}

        # JVP argument: the tangent fed to the linearized op.
        x_dot = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
        self.jvp_args = (x_dot, )
        self.jvp_out_shape_map = {0: self.prim_output['Y']}

        # Ops expected after lowering: the primitive itself, then its JVP
        # rule x_dot * cos(x); sin_p has no transpose rule.
        self.all_ops = ['sin_p', 'mul_p', 'cos_p']


class TestCosPJVPAndTranspose(TestAddPJVPAndTranspose):
    """JVP/transpose lowering test for the cos_p primitive."""

    def init_data(self):
        # Primitive under test: Y = cos(X).
        self.op_type = 'cos_p'
        x = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        self.prim_input = {'X': x}
        out = self.layer_help.create_variable_for_type_inference(dtype=x.dtype)
        self.prim_output = {'Y': out}
        self.prim_attrs = {}

        # JVP argument: the tangent fed to the linearized op.
        x_dot = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
        self.jvp_args = (x_dot, )
        self.jvp_out_shape_map = {0: self.prim_output['Y']}

        # Ops expected after lowering: the primitive itself, then its JVP
        # rule x_dot * (0 - sin(x)); cos_p has no transpose rule.
        self.all_ops = ['cos_p', 'mul_p', 'sin_p', 'fill_constant_p', 'sub_p']


class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose):
    """JVP/transpose lowering test for the exp_p primitive."""

    def init_data(self):
        # Primitive under test: Y = exp(X).
        self.op_type = 'exp_p'
        x = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        self.prim_input = {'X': x}
        out = self.layer_help.create_variable_for_type_inference(dtype=x.dtype)
        self.prim_output = {'Y': out}
        self.prim_attrs = {}

        # JVP argument: the tangent fed to the linearized op.
        x_dot = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
        self.jvp_args = (x_dot, )
        self.jvp_out_shape_map = {0: self.prim_output['Y']}

        # Ops expected after lowering: the primitive itself, then its JVP
        # rule x_dot * exp(x) (reuses the forward output); no transpose rule.
        self.all_ops = ['exp_p', 'mul_p']


class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose):

def init_data(self):
Expand Down
60 changes: 60 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,66 @@ def init_data(self):
self.out_map = {0: self.output['Out']}


class TestSinOrig2Prim(TestElementWiseAddOrig2Prim):
    """Checks that the original `sin` op lowers to the sin_p primitive."""

    def init_data(self):
        self.op_type = 'sin'
        x = paddle.static.data(name='X', shape=[3, 4], dtype='float')
        self.input = {'X': x}
        out = self.layer_help.create_variable_for_type_inference(dtype=x.dtype)
        self.output = {'Out': out}
        self.attrs = {}

        self.orig2prim_args = (x, )
        # Program is expected to contain the original op plus its lowering.
        self.all_ops = ['sin', 'sin_p']
        self.out_map = {0: self.output['Out']}


class TestCosOrig2Prim(TestElementWiseAddOrig2Prim):
    """Checks that the original `cos` op lowers to the cos_p primitive."""

    def init_data(self):
        self.op_type = 'cos'
        x = paddle.static.data(name='X', shape=[3, 4], dtype='float')
        self.input = {'X': x}
        out = self.layer_help.create_variable_for_type_inference(dtype=x.dtype)
        self.output = {'Out': out}
        self.attrs = {}

        self.orig2prim_args = (x, )
        # Program is expected to contain the original op plus its lowering.
        self.all_ops = ['cos', 'cos_p']
        self.out_map = {0: self.output['Out']}


class TestExpOrig2Prim(TestElementWiseAddOrig2Prim):
    """Checks that the original `exp` op lowers to the exp_p primitive."""

    def init_data(self):
        self.op_type = 'exp'
        x = paddle.static.data(name='X', shape=[3, 4], dtype='float')
        self.input = {'X': x}
        out = self.layer_help.create_variable_for_type_inference(dtype=x.dtype)
        self.output = {'Out': out}
        self.attrs = {}

        self.orig2prim_args = (x, )
        # Program is expected to contain the original op plus its lowering.
        self.all_ops = ['exp', 'exp_p']
        self.out_map = {0: self.output['Out']}


class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
Expand Down
Loading

0 comments on commit a375701

Please sign in to comment.