Support both use_calc_stream and sync_op in send recv APIs (PaddlePad…
HermitSun authored and cxxly committed Sep 26, 2022
1 parent 159f10e commit 9c0ce9b
Showing 2 changed files with 63 additions and 6 deletions.
57 changes: 57 additions & 0 deletions python/paddle/incubate/autograd/primops.py
@@ -92,6 +92,63 @@ def set_value(x, y, axis, starts, ends, strides, out):
    return out


def mean(x, axis=None, keepdim=False):
    """Lower mean to reduce_sum(x, axes) divided by the number of
    reduced elements."""
    axes = axis or tuple(range(0, len(x.shape)))
    sum = reduce_sum(x, axis=axes, keepdim=keepdim)
    norm = fill_const(shape=sum.shape,
                      value=functools.reduce(operator.mul,
                                             [x.shape[a] for a in axes]),
                      dtype=sum.dtype)
    return div(sum, norm)
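For intuition, primops.mean composes reduce_sum, fill_const and div into sum / count. A plain-NumPy reference of the same computation (my own sketch; numpy and the helper name are not part of the committed file):

import functools
import operator

import numpy as np


def mean_reference(x, axis=None, keepdim=False):
    # Mirrors primops.mean above: sum over the reduced axes, divided
    # by the number of reduced elements.
    axes = axis or tuple(range(x.ndim))
    total = x.sum(axis=axes, keepdims=keepdim)
    count = functools.reduce(operator.mul, [x.shape[a] for a in axes])
    return total / count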


def ones(shape, dtype):
    return fill_const(1.0, shape, dtype)


def batch_norm(x,
               axis,
               gamma,
               beta,
               run_mean,
               run_var,
               eps=1e-5,
               momentum=0.9,
               use_run_stat=False):
    """Batch normalization.

    Args:
        x (Tensor): The tensor to be normalized.
        axis (int): The feature axis.
        gamma (Tensor): The scale factor.
        beta (Tensor): The shift factor.
        run_mean (Tensor): Running mean.
        run_var (Tensor): Running variance.
        eps (float, optional): A value added to the denominator for numerical
            stability. Defaults to 1e-5.
        momentum (float, optional): The weight used to update the running
            mean and running variance. Defaults to 0.9.
        use_run_stat (bool, optional): Whether to normalize with the running
            statistics instead of the current batch statistics. Defaults to
            False.

    Returns:
        A (normalized output, updated running mean, updated running variance)
        tuple.
    """
    reduce_axes = tuple(i for i in range(len(x.shape)) if i != axis)

    if not use_run_stat:
        m = mean(x, reduce_axes, keepdim=True)
        v = mean(square(sub(x, m)), reduce_axes, keepdim=True)
        x_hat = div(sub(x, m),
                    sqrt(add(v, fill_const(eps, v.shape, v.dtype))))

        momentum = fill_const(momentum, run_mean.shape, run_mean.dtype)
        one = ones(run_mean.shape, run_mean.dtype)
        # Exponential moving average of the batch statistics.
        run_mean = add(mul(momentum, run_mean), mul(sub(one, momentum), m))
        run_var = add(mul(momentum, run_var), mul(sub(one, momentum), v))
    else:
        # eps is lifted to a constant tensor so the add primop gets
        # tensor operands, matching the training branch above.
        x_hat = div(sub(x, run_mean),
                    sqrt(add(run_var,
                             fill_const(eps, run_var.shape, run_var.dtype))))

    return add(mul(gamma, x_hat), beta), run_mean, run_var
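The function above implements the standard batch-norm equations: training mode normalizes with the batch statistics and folds them into the running statistics via an exponential moving average, while inference mode (use_run_stat=True) normalizes with the running statistics directly. A minimal NumPy sketch of those semantics (assumed helper, not part of the commit; gamma, beta, run_mean and run_var are assumed pre-shaped for broadcasting, e.g. (1, C, 1, 1) for NCHW input with axis=1):

import numpy as np


def batch_norm_reference(x, axis, gamma, beta, run_mean, run_var,
                         eps=1e-5, momentum=0.9, use_run_stat=False):
    # Normalize over every axis except the feature axis.
    reduce_axes = tuple(i for i in range(x.ndim) if i != axis)
    if not use_run_stat:
        m = x.mean(axis=reduce_axes, keepdims=True)
        v = ((x - m) ** 2).mean(axis=reduce_axes, keepdims=True)
        x_hat = (x - m) / np.sqrt(v + eps)
        # Exponential moving average of the batch statistics.
        run_mean = momentum * run_mean + (1 - momentum) * m
        run_var = momentum * run_var + (1 - momentum) * v
    else:
        x_hat = (x - run_mean) / np.sqrt(run_var + eps)
    return gamma * x_hat + beta, run_mean, run_var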


@REGISTER_FN('add_p', 'X', 'Y', 'Z')
def add(x, y, out=None):
    return _simple_binop(LayerHelper('add_p', **locals()))
12 changes: 6 additions & 6 deletions python/paddle/incubate/autograd/primrules.py
@@ -436,12 +436,12 @@ def reduce_sum_orig2prim(op, x):
def reduce_mean_orig2prim(op, x):
    axes = tuple(range(0, len(
        x.shape))) if op.attr('reduce_all') else op.attr('dim')
-    sum = reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim'))
-    norm = fill_const(shape=sum.shape,
-                      value=functools.reduce(operator.mul,
-                                             [x.shape[axis] for axis in axes]),
-                      dtype=sum.dtype)
-    return div(sum, norm)
+    return primops.mean(x, axes, op.attr('keep_dim'))
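The rewrite is behavior-preserving: the removed body computed sum(x, axes) divided by the product of the reduced dimension sizes, which is exactly what primops.mean lowers to. A quick NumPy check of that identity (my own sketch, not part of the commit):

import numpy as np

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
axes = (0, 2)
inline = x.sum(axis=axes) / np.prod([x.shape[a] for a in axes])
assert np.allclose(inline, x.mean(axis=axes))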


+@REGISTER_ORIG2PRIM('batch_norm')
+def batch_norm_orig2prim(op, x):
+    pass


@REGISTER_ORIG2PRIM('size')
