Skip to content

Commit

Permalink
[Bugfix] Fix layer norm for large input shape (apache#14870)
Browse files Browse the repository at this point in the history
* fix layer norm for large input shape

* try to fix

* use a larger eps

* try to fix test

* try to fix
  • Loading branch information
sxjscience authored and Rohit Kumar Srivastava committed May 14, 2019
1 parent 8356d9a commit 725cfe6
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/operator/nn/layer_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,14 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
reduce_workspace_size =
std::max(reduce_workspace_size,
broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_src_shape,
kAddTo, red_dst_shape));
broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_dst_shape,
kAddTo, red_src_shape));
});
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
reduce_workspace_size =
std::max(reduce_workspace_size,
broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_src_shape, kAddTo,
red_exclude_dst_shape));
broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_dst_shape, kAddTo,
red_exclude_src_shape));
});
});
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Expand Down
91 changes: 82 additions & 9 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3364,7 +3364,9 @@ def test_l2_normalization():
check_l2_normalization((nbatch, nchannel, height, width), mode, dtype)


def check_layer_normalization(in_shape, axis, eps, dtype=np.float32, forward_check_eps=1E-3):
def check_layer_normalization(in_shape, axis, eps, dtype=np.float32,
forward_check_eps=1E-3, backward_check_eps=1E-3,
npy_grad_check=True, finite_grad_check=True):
def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
if axis < 0:
axis += data.ndim
Expand All @@ -3377,6 +3379,24 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
np.reshape(beta, broadcast_shape)
return out

def npy_layer_norm_grad(data, gamma, out_grad, axis, eps):
if axis < 0:
axis += data.ndim
exclude_axis = tuple([ele for ele in range(data.ndim) if ele != axis])
data_mean = data.mean(axis=axis, keepdims=True)
data_var = data.var(axis=axis, keepdims=True)
data_std = np.sqrt(data_var + eps)
centered_data = (data - data_mean) / data_std
gamma_grad = (centered_data * out_grad).sum(axis=exclude_axis, keepdims=True)
beta_grad = out_grad.sum(axis=exclude_axis, keepdims=True)
w = out_grad * gamma.reshape([1 if i != axis else data.shape[axis] for i in range(data.ndim)])\
/ data_std
data_grad = w - w.mean(axis=axis, keepdims=True)\
- centered_data * (w * centered_data).mean(axis=axis, keepdims=True)
gamma_grad = gamma_grad.reshape((-1,))
beta_grad = beta_grad.reshape((-1,))
return data_grad, gamma_grad, beta_grad

ctx = default_context()
data = np.random.normal(0, 1, in_shape).astype(dtype)
gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
Expand All @@ -3392,10 +3412,51 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
out_nd = exe.forward()[0]
out = npy_layer_norm(data, gamma, beta, axis, eps)
assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps, forward_check_eps)
for req in ['write', 'add']:
check_numeric_gradient(out_s, {'data': data, 'gamma': gamma, 'beta': beta},
grad_nodes={'data': req, 'gamma': req, 'beta': req},
numeric_eps=1e-2, rtol=1e-2, atol=1e-2)

if finite_grad_check:
for req in ['write', 'add']:
check_numeric_gradient(out_s, {'data': data, 'gamma': gamma, 'beta': beta},
grad_nodes={'data': req, 'gamma': req, 'beta': req},
numeric_eps=1e-2, rtol=1e-2, atol=1e-2)

if npy_grad_check:
# Test for grad_req = write
out_grad = np.random.normal(0, 1, in_shape).astype(dtype)
exe = out_s.simple_bind(ctx, data=in_shape, grad_req='write')
exe.arg_dict['data'][:] = data
exe.arg_dict['gamma'][:] = gamma
exe.arg_dict['beta'][:] = beta
exe.forward()
exe.backward([mx.nd.array(out_grad, ctx=ctx)])
gt_data_grad, gt_gamma_grad, gt_beta_grad =\
npy_layer_norm_grad(data, gamma, out_grad, axis, eps)
assert_almost_equal(exe.grad_dict['data'].asnumpy(), gt_data_grad, backward_check_eps, backward_check_eps)
assert_almost_equal(exe.grad_dict['gamma'].asnumpy(), gt_gamma_grad, backward_check_eps, backward_check_eps)
assert_almost_equal(exe.grad_dict['beta'].asnumpy(), gt_beta_grad, backward_check_eps, backward_check_eps)

# Test for grad_req = add
out_grad = np.random.normal(0, 1, in_shape).astype(dtype)
init_data_grad = np.random.normal(0, 1, in_shape).astype(dtype)
init_gamma_grad = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
init_beta_grad = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
exe = out_s.simple_bind(ctx, data=in_shape, grad_req='add')
exe.arg_dict['data'][:] = data
exe.arg_dict['gamma'][:] = gamma
exe.arg_dict['beta'][:] = beta
exe.grad_dict['data'][:] = init_data_grad
exe.grad_dict['gamma'][:] = init_gamma_grad
exe.grad_dict['beta'][:] = init_beta_grad
exe.forward()
exe.backward([mx.nd.array(out_grad, ctx=ctx)])
gt_data_grad, gt_gamma_grad, gt_beta_grad = \
npy_layer_norm_grad(data, gamma, out_grad, axis, eps)
assert_almost_equal(exe.grad_dict['data'].asnumpy(),
gt_data_grad + init_data_grad, backward_check_eps, backward_check_eps)
assert_almost_equal(exe.grad_dict['gamma'].asnumpy(),
gt_gamma_grad + init_gamma_grad, backward_check_eps, backward_check_eps)
assert_almost_equal(exe.grad_dict['beta'].asnumpy(),
gt_beta_grad + init_beta_grad, backward_check_eps, backward_check_eps)


@with_seed()
def test_norm():
Expand Down Expand Up @@ -3469,13 +3530,25 @@ def l2norm(input_data, axis=0, keepdims=True):


def test_layer_norm():
for dtype, forward_check_eps in zip([np.float16, np.float32, np.float64],
[1E-2, 1E-3, 1E-4]):
for in_shape in [(10, 6, 5), (10, 10)]:
for dtype, forward_check_eps, backward_check_eps in zip([np.float16, np.float32, np.float64],
[1E-2, 1E-3, 1E-4],
[1E-2, 1E-3, 1E-4]):
if dtype != np.float16:
in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10), (128 * 32, 512)], [True, True, False]
else:
in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10)], [True, True] # large input + fp16 does not pass the forward check
for in_shape, finite_grad_check in zip(in_shape_l, finite_grad_check_l):
for axis in range(-len(in_shape), len(in_shape)):
for eps in [1E-2, 1E-3]:
if dtype == np.float16:
npy_grad_check = False
else:
npy_grad_check = True
check_layer_normalization(in_shape, axis, eps, dtype=dtype,
forward_check_eps=forward_check_eps)
forward_check_eps=forward_check_eps,
backward_check_eps=backward_check_eps,
npy_grad_check=npy_grad_check,
finite_grad_check=finite_grad_check)


# Numpy Implementation of Sequence Ops
Expand Down

0 comments on commit 725cfe6

Please sign in to comment.