diff --git a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp index b4081334b74..99134794afc 100644 --- a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp +++ b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp @@ -107,13 +107,13 @@ Maybe LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple int64_t begin_norm_axis = ctx->begin_norm_axis; if (begin_norm_axis < 0) { begin_norm_axis += dy->shape()->NumAxes(); } - std::shared_ptr gamma = saved_tensors.at(ctx->gamma_index); if (!ctx->has_affine) { // Use LayerNormParamGrad(Tensor dy, Tensor gamma, Int64 begin_params_axis, Double epsilon). dy = JUST(functional::LayerNormParamGrad(dy, begin_params_axis, ctx->epsilon)); } else { // Use LayerNormAffineParamGrad(Tensor dy, Tensor gamma, Tensor normalized, Int64 // begin_params_axis, Double epsilon). + std::shared_ptr gamma = saved_tensors.at(ctx->gamma_index); std::shared_ptr normalized = saved_tensors.at(ctx->normalized_index); const auto& results = JUST(functional::LayerNormAffineParamGrad( dy, gamma, normalized, begin_params_axis, ctx->epsilon)); diff --git a/python/oneflow/test/modules/test_layernorm.py b/python/oneflow/test/modules/test_layernorm.py index 6022a07354f..07afe3e48a8 100644 --- a/python/oneflow/test/modules/test_layernorm.py +++ b/python/oneflow/test/modules/test_layernorm.py @@ -203,6 +203,24 @@ def get_random_norm_shape(): y = m(x) return y + @autotest(n=20, auto_backward=True, rtol=1e-3, atol=1e-3) + def test_layernorm_without_affine(test_case): + device = random_device() + channel = random(1, 200).to(int) + height = random(1, 2).to(int) + width = random(8192, 32768).to(int) + + def get_random_norm_shape(): + begin_axis = random(1, 3).to(int).value() + return tuple((channel.value(), height.value(), width.value())[begin_axis:]) + + m = torch.nn.LayerNorm(normalized_shape=get_random_norm_shape()).to(device) + x = random_pytorch_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to( + device + ) + y = m(x) + return y + if __name__ == "__main__": unittest.main()