[common][pyTorch]Add zero_centered_gamma option to RMSNorm (NVIDIA#631)
* Add zero_centered_gamma option to RMSNorm

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Improving tests

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* More improvements to tests

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Tweaking the tolerances

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix LayerNormMLP test

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Update transformer_engine/common/rmsnorm/rmsnorm_api.cpp

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/common/rmsnorm/rmsnorm_api.cpp

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs suggestions

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Tweak tolerances with bfloat16

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
3 people authored Feb 3, 2024
1 parent 5b155fb commit d68028c
Showing 10 changed files with 312 additions and 91 deletions.
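With zero_centered_gamma enabled, the RMSNorm weight is stored as an offset from one: the effective scale applied after normalization is (1 + gamma), so a zero-initialized parameter starts out as an identity scale. A minimal sketch of that semantics in PyTorch, with epsilon added inside the root as in the reference code below (the function name and tensor shapes are illustrative, not part of this commit):

```python
import torch

def rmsnorm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5,
                zero_centered_gamma: bool = False) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension.
    rms_x2 = x.pow(2).mean(dim=-1, keepdim=True) + eps
    x_normed = x * rms_x2.rsqrt()
    # With zero_centered_gamma, the stored weight is an offset from 1,
    # so the effective scale is (1 + gamma) rather than gamma.
    w = 1 + gamma if zero_centered_gamma else gamma
    return w * x_normed

x = torch.randn(4, 8)
out_standard = rmsnorm_ref(x, torch.ones(8))                             # usual init, gamma = 1
out_zero_ctr = rmsnorm_ref(x, torch.zeros(8), zero_centered_gamma=True)  # zero-centered init, gamma = 0
assert torch.allclose(out_standard, out_zero_ctr)  # identical at initialization
```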
77 changes: 47 additions & 30 deletions tests/cpp/operator/test_rmsnorm.cu
@@ -43,13 +43,17 @@ void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, con
template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output,
const float *rsigma, const size_t N, const size_t H, float *amax,
float scale) {
float scale, const bool zero_centered_gamma) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t tmp = current * rsigma[i] * static_cast<compute_t>(gamma[j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
compute_t tmp = current * rsigma[i] * g;
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
@@ -60,7 +64,7 @@ void compute_ref_output(const InputType *data, const InputType *gamma, OutputTyp
template <typename InputType, typename OutputType>
void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma,
const InputType *gamma, InputType *data_grad, InputType *gamma_grad,
const size_t N, const size_t H) {
const size_t N, const size_t H, const bool zero_centered_gamma) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);

@@ -70,7 +74,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
@@ -82,7 +89,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y);
@@ -97,7 +107,7 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
}

template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType";
return;
@@ -137,23 +147,25 @@ void performTest(const size_t N, const size_t H) {

// Forward kernel
float epsilon = 1e-5;
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
auto fwd_function = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());

// Backward kernel
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
auto bwd_function = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd;
bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());

// Reference implementations
// use the GPU stats to tighten the tolerances
@@ -162,10 +174,11 @@ void performTest(const size_t N, const size_t H) {
compute_ref_stats(input.cpu_dptr<InputType>(), ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(), gamma.cpu_dptr<WeightType>(), ref_output.get(),
rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale);
rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale,
zero_centered_gamma);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
rsigma.cpu_dptr<float>(), gamma.cpu_dptr<WeightType>(), ref_dx.get(),
ref_dgamma.get(), N, H);
ref_dgamma.get(), N, H, zero_centered_gamma);

cudaDeviceSynchronize();
auto err = cudaGetLastError();
@@ -197,9 +210,10 @@ std::vector<std::pair<size_t, size_t>> test_cases = {

} // namespace

class RMSNormTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
class RMSNormTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool>> {};

TEST_P(RMSNormTestSuite, TestRMSNorm) {
using namespace transformer_engine;
@@ -208,23 +222,26 @@ TEST_P(RMSNormTestSuite, TestRMSNorm) {
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
const bool zero_centered_gamma = std::get<3>(GetParam());

TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma);););
}

INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite,
::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16,
DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16,
DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases)),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<RMSNormTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
std::string name =
test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second) + "X" +
std::to_string(std::get<3>(info.param));
return name;
});
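The PyTorch tests below exercise the option through the module API. A usage sketch (constructor arguments mirror the test code in this commit; hidden size, dtype, and input shape are arbitrary):

```python
import torch
from transformer_engine.pytorch import RMSNorm

# The weight is stored zero-centered; the module applies (1 + weight) as the scale.
rmsnorm = RMSNorm(1024, eps=1e-5, zero_centered_gamma=True).to(dtype=torch.bfloat16).cuda()

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
y = rmsnorm(x)
```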
125 changes: 106 additions & 19 deletions tests/pytorch/test_numerics.py
@@ -21,7 +21,7 @@
)
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
MultiheadAttention, RMSNorm, TransformerLayer
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
@@ -215,15 +215,41 @@ def forward(
return context_layer


class TorchLayerNorm(nn.Module):
def __init__(self, in_features: int,
eps: float,
zero_centered_gamma: bool):
super().__init__()
self.eps = eps
self.in_features = in_features
self.zero_centered_gamma = zero_centered_gamma

initial_value = torch.ones(in_features) if zero_centered_gamma else torch.zeros(in_features)
self.weight = nn.Parameter(initial_value)
self.bias = nn.Parameter(torch.zeros(in_features))
self.register_parameter("weight", self.weight)
self.register_parameter("bias", self.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight if not self.zero_centered_gamma else 1 + self.weight
w = w.to(torch.float32)
b = self.bias.to(torch.float32)
inp = x.to(torch.float32)
out = torch.nn.functional.layer_norm(inp, (self.in_features,), weight=w,
bias=b, eps=self.eps)
return out.to(x.dtype)

# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
def __init__(self, in_features, eps=1e-5):
def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
super().__init__()

self.eps = eps
self.in_features = in_features
self.zero_centered_gamma = zero_centered_gamma

self.weight = nn.Parameter(torch.ones(in_features))
initial_value = torch.ones(in_features) if zero_centered_gamma else torch.zeros(in_features)
self.weight = nn.Parameter(initial_value)
self.register_parameter("weight", self.weight)

def forward(self, x):
@@ -234,18 +260,24 @@ def forward(self, x):
r_rms_x = rms_x2 ** (-1. / 2)
x_normed = x * r_rms_x

return (self.weight.float() * x_normed).to(x.dtype)
w = self.weight.float()
if self.zero_centered_gamma:
w = 1 + w
return (w * x_normed).to(x.dtype)


class TorchLayerNormLinear(nn.Module):
def __init__(self, in_features: int, out_features: int,
eps: float, bias: bool = True,
normalization: str = "LayerNorm"):
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False):
super().__init__()
if normalization == "LayerNorm":
self.layernorm = nn.LayerNorm(in_features, eps=eps)
self.layernorm = TorchLayerNorm(in_features, eps=eps,
zero_centered_gamma=zero_centered_gamma)
elif normalization == "RMSNorm":
self.layernorm = TorchRMSNorm(in_features, eps=eps)
self.layernorm = TorchRMSNorm(in_features, eps=eps,
zero_centered_gamma=zero_centered_gamma)
else:
raise RuntimeError("Unsupported normalization")

@@ -299,9 +331,11 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int,
normalization: str = "LayerNorm"):
super().__init__()
if normalization == "LayerNorm":
self.ln = nn.LayerNorm(hidden_size, eps=eps)
self.ln = TorchLayerNorm(hidden_size, eps=eps,
zero_centered_gamma=False)
elif normalization == "RMSNorm":
self.ln = TorchRMSNorm(hidden_size, eps=eps)
self.ln = TorchRMSNorm(hidden_size, eps=eps,
zero_centered_gamma=False)
else:
raise RuntimeError("Unsupported normalization")
if 'glu' in activation:
Expand Down Expand Up @@ -893,13 +927,15 @@ def test_linear_accuracy(dtype, bs, model):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
def test_rmsnorm_accuracy(dtype, bs, model, eps):
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]

te_rmsnorm = (
RMSNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
@@ -910,6 +946,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
TorchRMSNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
@@ -924,17 +961,64 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 1e-7)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 2e-2)
atol = {torch.float32 : 1e-7,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]

te_layernorm = (
LayerNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
.eval()
)

torch_layernorm = (
TorchLayerNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
.eval()
)

# Share params
with torch.no_grad():
torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
torch_layernorm.bias = Parameter(te_layernorm.bias.clone())

te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)

# Check output.
atol = {torch.float32 : 1e-7,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
config = model_configs[model]

te_ln_linear = (
Expand All @@ -944,6 +1028,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
config.eps,
bias=True,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
@@ -957,6 +1042,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
config.eps,
bias=True,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
@@ -975,10 +1061,11 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
atol = {torch.float32 : 2e-4,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])


@pytest.mark.parametrize("dtype", param_types)