Skip to content

Commit

Permalink
add batch_norm orig2prim transform rule (#46446)
Browse files Browse the repository at this point in the history
* Support both use_calc_stream and sync_op in send recv APIs (#46023)

* add batch_norm prim2orig rule

Co-authored-by: Wen Sun <35923278+HermitSun@users.noreply.github.com>
  • Loading branch information
cxxly and HermitSun authored Sep 26, 2022
1 parent 23c5064 commit b0ec8ef
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 8 deletions.
76 changes: 76 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,5 +900,81 @@ def init_data(self):
self.out_map = {0: self.output['Out']}


class TestBatchnormOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'batch_norm'
x = paddle.static.data(name='X', shape=[5, 8], dtype='float')
m = paddle.static.data(name='Mean', shape=[8], dtype='float')
v = paddle.static.data(name='Variance', shape=[8], dtype='float')
w = paddle.static.data(name='Scale', shape=[8], dtype='float')
b = paddle.static.data(name='Bias', shape=[8], dtype='float')

self.input = {
"X": [x],
"Scale": [w],
"Bias": [b],
"Mean": [m],
"Variance": [v]
}
saved_variance = self.layer_help.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
batch_norm_out = self.layer_help.create_variable_for_type_inference(
x.dtype)
saved_mean = self.layer_help.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
self.output = {
"Y": [batch_norm_out],
"MeanOut": [m],
"VarianceOut": [v],
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance]
}

self.attrs = {
"momentum": 0.9,
"epsilon": 1e-5,
"is_test": False,
"data_layout": 'NCHW',
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": False,
"trainable_statistics": False,
}
self.orig2prim_args = (b, m, None, w, v, x)
self.all_ops = [
'add_p', 'add_p', 'add_p', 'add_p', 'batch_norm', 'broadcast_p',
'broadcast_p', 'broadcast_p', 'broadcast_p', 'broadcast_p', 'div_p',
'div_p', 'div_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'mul_p', 'mul_p', 'mul_p', 'mul_p', 'mul_p',
'pow_p', 'reduce_sum_p', 'reduce_sum_p', 'reshape_p', 'reshape_p',
'reshape_p', 'reshape_p', 'sqrt_p', 'sub_p', 'sub_p', 'sub_p',
'sub_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {}


class TestFillConstantOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'fill_constant'

self.attrs = {'value': 1., 'shape': (2, 3), 'dtype': paddle.float32}
self.input = {}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.float32)
}

self.orig2prim_args = (None, None, None)
self.all_ops = ['fill_constant', 'fill_constant_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}


if __name__ == '__main__':
unittest.main()
34 changes: 33 additions & 1 deletion python/paddle/fluid/tests/unittests/autograd/test_primapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,39 @@ def test_illegal_param(self):
lambda x: paddle.var(x, axis=1, unbiased=False),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('var_with_keepdim', lambda x: paddle.var(x, axis=1, keepdim=True),
(np.random.rand(10, 20, 30), ), None, 'float32')))
(np.random.rand(10, 20, 30), ), None, 'float32'),
('bn', lambda x, w, b: paddle.nn.functional.batch_norm(
x, paddle.ones((10, )), paddle.ones(
(10, )), w, b), (np.random.rand(10, 10), np.random.rand(10),
np.random.rand(10)), None, 'float32'),
('bn_train', lambda x, w, b: paddle.nn.functional.batch_norm(
x, paddle.ones((10, )), paddle.ones((10, )), w, b, training=True),
(np.random.rand(
10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
('bn_nhwc', lambda x, w, b: paddle.nn.functional.batch_norm(
x,
paddle.ones((10, )) + 1,
paddle.ones((10, )),
w,
b,
training=True,
data_format='NHWC',
), (np.random.rand(
10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
('bn_global_stat',
lambda x, w, b: paddle.nn.functional.batch_norm(x,
paddle.ones(
(10, )) + 3.2,
paddle.ones(
(10, )) + 6.7,
w,
b,
training=True,
data_format='NHWC',
use_global_stats=True),
(np.random.rand(
10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
))
class TestGrad(unittest.TestCase):

def setUp(self):
Expand Down
94 changes: 94 additions & 0 deletions python/paddle/incubate/autograd/primops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import operator

import paddle
from paddle.fluid.layer_helper import LayerHelper

Expand Down Expand Up @@ -92,6 +95,97 @@ def set_value(x, y, axis, starts, ends, strides, out):
return out


def mean(x, axis=None, keepdim=False):
axes = axis or tuple(range(0, len(x.shape)))
sum = reduce_sum(x, axis=axes, keepdim=keepdim)
norm = fill_const(shape=sum.shape,
value=functools.reduce(operator.mul,
[x.shape[axis] for axis in axes]),
dtype=sum.dtype)
return div(sum, norm)


def ones(shape, dtype):
return fill_const(1, shape, dtype)


def zeros(shape, dtype):
return fill_const(0, shape, dtype)


def batch_norm(x,
axis,
gamma,
beta,
run_mean,
run_var,
eps=1e-5,
momentum=0.9,
use_run_stat=False,
reserve_space=None):
"""batch normalizer.
Args:
x (Tensor): A tensor to be normalized.
axis (int): The features axis.
gamma (Tensor): The scale factor.
beta (float): The shift factor.
run_mean (Tensor): Running mean.
run_var (Tensor): Running variance.
eps (float, optional): A value added to the denominator for numerical
stability. Defaults to 1e-5.
momentum (float, optional): The value used for the running_mean and
running_var computation. Can be set to None for cumulative moving
average (i.e. simple average). Defaults to 0.9.
use_run_stat (bool, optional): Whether or not using runing statistics.
Defaults to False.
"""
reduce_axes = tuple(i for i in range(len(x.shape)) if i != axis)
stats_shape = tuple(1 if i in reduce_axes else s
for i, s in enumerate(x.shape))

batch_mean = zeros(run_mean.shape, run_mean.dtype)
batch_var = zeros(run_var.shape, run_var.dtype)

if not use_run_stat:
batch_mean = mean(x, reduce_axes, keepdim=True)
batch_var = mean(square(sub(x, broadcast(batch_mean, x.shape))),
reduce_axes,
keepdim=True)
x_hat = div(
sub(x, broadcast(batch_mean, x.shape)),
sqrt(
add(broadcast(batch_var, x.shape),
fill_const(eps, x.shape, batch_var.dtype))))

momentum = fill_const(momentum, run_mean.shape, run_mean.dtype)
run_mean = add(
mul(momentum, run_mean),
mul(sub(ones(run_mean.shape, run_mean.dtype), momentum),
reshape(batch_mean, run_mean.shape)))
run_var = add(
mul(momentum, run_var),
mul(sub(ones(run_var.shape, run_var.dtype), momentum),
reshape(batch_var, run_var.shape)))
else:
x_hat = div(
sub(x, broadcast(reshape(run_mean, stats_shape), x.shape)),
sqrt(
add(broadcast(reshape(run_var, stats_shape), x.shape),
fill_const(eps, x.shape, x.dtype))))
y = add(mul(broadcast(reshape(gamma, stats_shape), x_hat.shape), x_hat),
broadcast(reshape(beta, stats_shape), x_hat.shape))

if reserve_space:
return run_mean, reserve_space, batch_mean, batch_var, run_var, y
else:
return run_mean, batch_mean, batch_var, run_var, y


def square(x):
return pow(x, fill_const(2., x.shape, x.dtype))


@REGISTER_FN('add_p', 'X', 'Y', 'Z')
def add(x, y, out=None):
return _simple_binop(LayerHelper('add_p', **locals()))
Expand Down
51 changes: 44 additions & 7 deletions python/paddle/incubate/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,20 @@ def fill_any_like_orig2prim(op, x):
convert_dtype(INT_DTYPE_2_STRING[op.attr('dtype')])))


@REGISTER_ORIG2PRIM('fill_constant')
def fill_const_orig2prim(op,
shape_tensor=None,
shape_tensor_list=None,
value_tensor=None):
if shape_tensor or shape_tensor_list or value_tensor:
raise TypeError(
'fill_const_orig2prim currently not support Tensor input of shape and value.'
)
return fill_const(value=op.attr('value'),
shape=op.attr('shape'),
dtype=paddle.dtype(op.attr('dtype')))


@REGISTER_ORIG2PRIM('sum')
def sum_orig2prim(op, xs):
x0 = xs[0]
Expand Down Expand Up @@ -391,7 +405,7 @@ def pow_orig2prim(op, x, y):

@REGISTER_ORIG2PRIM('square')
def square_orig2prim(op, x):
return primops.pow(x, fill_const(2., x.shape, x.dtype))
return primops.square(x)


@REGISTER_ORIG2PRIM('elementwise_max')
Expand Down Expand Up @@ -436,12 +450,35 @@ def reduce_sum_orig2prim(op, x):
def reduce_mean_orig2prim(op, x):
axes = tuple(range(0, len(
x.shape))) if op.attr('reduce_all') else op.attr('dim')
sum = reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim'))
norm = fill_const(shape=sum.shape,
value=functools.reduce(operator.mul,
[x.shape[axis] for axis in axes]),
dtype=sum.dtype)
return div(sum, norm)
return primops.mean(x, axes, op.attr('keep_dim'))


@REGISTER_ORIG2PRIM('batch_norm')
def batch_norm_orig2prim(op, bias, run_mean, momentum_tensor, scale, run_var,
x):
momentum = op.attr('momentum')
eps = op.attr('epsilon')
is_test = op.attr('is_test')
data_layout = op.attr('data_layout')
use_global_stats = op.attr('use_global_stats')
trainable_statistics = op.attr('trainable_statistics')
reserve_space = None if len(
op.output_names) == 5 else get_output_var_list(op)[1]

feature_axis = 1 if data_layout in ('NC', 'NCL', 'NCHW',
'NCHWD') else len(x.shape) - 1
use_run_stat = (is_test and (not trainable_statistics)) or use_global_stats

return primops.batch_norm(x,
feature_axis,
scale,
bias,
run_mean,
run_var,
eps=eps,
momentum=momentum,
use_run_stat=use_run_stat,
reserve_space=reserve_space)


@REGISTER_ORIG2PRIM('size')
Expand Down

0 comments on commit b0ec8ef

Please sign in to comment.