add gelu and erf primitive operators for new autograd (#45338)
* add erf_p primitive operator

* add gelu orig2prim rule
cxxly authored Sep 1, 2022
1 parent 1a0ef45 commit 4ed6f3b
Showing 10 changed files with 313 additions and 21 deletions.
6 changes: 5 additions & 1 deletion paddle/fluid/operators/prim_ops/CMakeLists.txt
@@ -22,13 +22,17 @@ set(PRIM_OP_SRCS
div_p_op.cc
sqrt_p_op.cc
tanh_p_op.cc
sin_p_op.cc
cos_p_op.cc
exp_p_op.cc
matmul_p_op.cc
fill_constant_p_op.cc
log_p_op.cc
select_p_op.cc
eq_p_op.cc
pow_p_op.cc
max_p_op.cc)
max_p_op.cc
erf_p_op.cc)

cc_test(
prim_op_test
78 changes: 78 additions & 0 deletions paddle/fluid/operators/prim_ops/erf_p_op.cc
@@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace operators {
class ErfPrimOp : public framework::OperatorBase {
public:
ErfPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator erf_p should not be excuted directly"));
}
};

class ErfPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of erf_p op.");
AddOutput("Y", "(Tensor), The output tensor of erf_p op.");
AddComment(R"DOC(Autograd primitive erf_p operator.)DOC");
}
};

class ErfPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};

class ErfPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(erf_p,
paddle::operators::ErfPrimOp,
paddle::operators::ErfPrimOpMaker,
paddle::operators::ErfPrimOpShapeInference,
paddle::operators::ErfPrimOpVarTypeInference);
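
For context (standard definition, not part of the diff): the function this primitive represents is the Gauss error function,

\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}}\,dt

and the C++ class is deliberately a shell (RunImpl only throws), so erf_p exists purely for shape/dtype inference and for the Python transform rules tested below.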
20 changes: 20 additions & 0 deletions paddle/fluid/operators/prim_ops/prim_op_test.cc
@@ -39,6 +39,7 @@ USE_OP_ITSELF(select_p);
USE_OP_ITSELF(eq_p);
USE_OP_ITSELF(pow_p);
USE_OP_ITSELF(max_p);
USE_OP_ITSELF(erf_p);

namespace paddle {
namespace framework {
@@ -710,5 +711,24 @@ TEST(PrimOp, max_p) {
ASSERT_EQ(shapes[2], 4L);
}

TEST(PrimOp, erf_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};

std::string x0 = "x0";
std::string x1 = "x1";

NewVar(block, x0, shape);
AppendOp(block, "erf_p", {{"X", {x0}}}, {{"Y", {x1}}}, {});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}

} // namespace framework
} // namespace paddle
@@ -364,6 +364,42 @@ def init_data(self):
]


class TestErfPJVPAndTranspose(TestAddPJVPAndTranspose):

def init_data(self):
# Set prim op
self.op_type = 'erf_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}

# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}

self.all_ops = [
# prim op:
'erf_p',
# jvp op:
'exp_p',
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'mul_p',
'mul_p',
'pow_p',
'sub_p',
# transpose op:
]
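
As a reading of the expected op list (an interpretation; the diff does not show the JVP rule itself), the derivative of erf is

\frac{d}{dx}\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}}\,e^{-x^{2}}

so the JVP computes \dot{Y} = \dot{X} \cdot \frac{2}{\sqrt{\pi}} e^{-X^{2}}. That would account for every op above: pow_p squares X, sub_p negates the square (as 0 - X^2), exp_p exponentiates, the two mul_p ops apply the 2/\sqrt{\pi} factor and the tangent X_DOT, and the three fill_constant_p ops supply the constants.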


class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose):

def init_data(self):
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
@@ -208,6 +208,26 @@ def init_data(self):
self.out_map = {0: self.output['Out']}


class TestErfOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'erf'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')

self.input = {
'X': X,
}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}

self.orig2prim_args = (X, )
self.all_ops = ['erf', 'erf_p']
self.out_map = {0: self.output['Out']}


class TestLogOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
@@ -559,5 +579,50 @@ def init_data(self):
self.out_map = {0: self.output['Out']}


class TestGeluOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'gelu'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')

self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'approximate': False}

self.orig2prim_args = (X, )
self.all_ops = [
'gelu', 'add_p', 'erf_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'mul_p', 'mul_p', 'mul_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
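
The expected op list is consistent with the exact GELU identity (the same formula as the gelu_ag reference used later in test_primapi.py):

\operatorname{gelu}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)

Plausibly: one erf_p, an add_p for the "1 +", three mul_p for the scalings by 1/\sqrt{2}, x and 1/2, and three fill_constant_p for those constants.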


class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'gelu'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')

self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'approximate': True}

self.orig2prim_args = (X, )
self.all_ops = [
'add_p', 'add_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'fill_constant_p', 'fill_constant_p', 'gelu',
'mul_p', 'mul_p', 'mul_p', 'mul_p', 'pow_p', 'tanh_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
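
Likewise for approximate=True, the op list matches the tanh approximation (again mirroring gelu_ag below):

\operatorname{gelu}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\,\left(x + 0.044715\,x^{3}\right)\right)\right)

Here pow_p and tanh_p cover the cubic and the tanh, while the remaining add_p, mul_p and fill_constant_p ops build the inner polynomial and supply the five constants (3, 0.044715, \sqrt{2/\pi}, 1, 0.5).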


if __name__ == '__main__':
unittest.main()
20 changes: 20 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
@@ -224,6 +224,26 @@ def init_data(self):
self.out_map = {self.output['Y']: 0}


class TestErfPPrim2Orig(TestAddPPrim2Orig):

def init_data(self):
self.op_type = 'erf_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}

self.prim2orig_args = (X, )
self.all_ops = ['erf_p', 'erf']
self.out_map = {self.output['Y']: 0}


class TestLogPPrim2Orig(TestAddPPrim2Orig):

def init_data(self):
55 changes: 38 additions & 17 deletions python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -16,11 +16,11 @@
import unittest

import numpy as np
import autograd
import autograd.numpy as np_autograd

import paddle

import autograd
import autograd.numpy as anp
import autograd.scipy as ascipy
import config
import utils

@@ -278,6 +278,11 @@ def test_illegal_param(self):
np.array([1, 2, 3]),
np.array([2, 2, 2]),
), None, 'float32'),
('erf', paddle.erf, (np.random.rand(300, 288), ), None, 'float32'),
('gelu', paddle.nn.functional.gelu,
(np.random.rand(200, 189), ), None, 'float32'),
('gelu_approximate', lambda x: paddle.nn.functional.gelu(x, True),
(np.random.rand(200, 189), ), None, 'float32'),
))
class TestGrad(unittest.TestCase):

@@ -397,25 +402,41 @@ def multiply_pd(x):


multiply_ag = lambda xs: xs[0] * xs[0] * xs[0] * xs[0] * xs[0]
sin_ag = lambda xs: np_autograd.sin(xs[0])
cos_ag = lambda xs: np_autograd.cos(xs[0])
exp_ag = lambda xs: np_autograd.exp(xs[0])
sin_ag = lambda xs: anp.sin(xs[0])
cos_ag = lambda xs: anp.cos(xs[0])
exp_ag = lambda xs: anp.exp(xs[0])
pow_ag = lambda xs: xs[0]**xs[1]
log_ag = lambda xs: np_autograd.log(xs[0])
log_ag = lambda xs: anp.log(xs[0])
erf_ag = lambda xs: ascipy.special.erf(xs[0])


def gelu_ag(x, approximate=False):
if approximate:
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
cdf = 0.5 * (1.0 + anp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x**3))))
return x * cdf
else:
return x * (ascipy.special.erf(x / np.sqrt(2)) + 1) / 2


@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'), (
('multiply', multiply_pd, multiply_ag,
(np.random.rand(3, 5), ), None, 'float32'),
('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'),
('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'),
('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'),
('pow', paddle.pow, pow_ag,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'),
))
(utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'),
(('multiply', multiply_pd, multiply_ag,
(np.random.rand(3, 5), ), None, 'float32'),
('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'),
('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'),
('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'),
('pow', paddle.pow, pow_ag,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'),
('erf', paddle.erf, erf_ag, (np.random.rand(100, 200), ), None, 'float32'),
('gelu', paddle.nn.functional.gelu, lambda xs: gelu_ag(xs[0]),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('gelu_approximate',
lambda x: paddle.nn.functional.gelu(x, approximate=True),
lambda xs: gelu_ag(xs[0], approximate=True),
(np.random.rand(10, 20, 30), ), None, 'float32')))
class TestGradWithHigherOrder(unittest.TestCase):

def setUp(self):
@@ -41,6 +41,7 @@
('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'),
('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'),
('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'),
('reshape', primops.reshape, randn(2, 3), {
'shape': (3, 2)
5 changes: 5 additions & 0 deletions python/paddle/incubate/autograd/primops.py
@@ -355,3 +355,8 @@ def pow(x, y, out=None):
@REGISTER_FN('max_p', 'X', 'Y', 'Z')
def max(x, y, out=None):
return _simple_binop(LayerHelper('max_p', **locals()))


@REGISTER_FN('erf_p', 'X', 'Y')
def erf(x, out=None):
return _simple_unop(LayerHelper('erf_p', **locals()))
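
A minimal usage sketch (hypothetical: it assumes primops is importable as below and that a static-graph program is active; erf_p itself cannot be run directly, since RunImpl throws, so a real program must be lowered with the prim2orig transform before execution):

import paddle
from paddle.incubate.autograd import primops

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[3, 4], dtype='float32')
    # Appends an erf_p op; Y inherits X's shape and dtype via
    # ErfPrimOpShapeInference / ErfPrimOpVarTypeInference.
    y = primops.erf(x)

print([op.type for op in main.block(0).ops])  # expected: ['erf_p']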
