From e0baf1c91ed6c5dc76227cc3a9abd8288c1b6951 Mon Sep 17 00:00:00 2001 From: Lucas Santos Date: Tue, 1 Jul 2025 19:10:43 +0000 Subject: [PATCH 1/3] Refactor RMSNorm Triton tests --- op_tests/triton_tests/test_rmsnorm.py | 250 ++++++++++++++++- op_tests/triton_tests/test_rmsnorm_quant.py | 282 -------------------- 2 files changed, 236 insertions(+), 296 deletions(-) delete mode 100644 op_tests/triton_tests/test_rmsnorm_quant.py diff --git a/op_tests/triton_tests/test_rmsnorm.py b/op_tests/triton_tests/test_rmsnorm.py index df54e8f241..2c20ce4e61 100644 --- a/op_tests/triton_tests/test_rmsnorm.py +++ b/op_tests/triton_tests/test_rmsnorm.py @@ -4,9 +4,14 @@ import pytest import torch import triton +import aiter from aiter.ops.triton.rmsnorm import ( rms_norm, rmsnorm2d_fwd_with_add, + rmsnorm2d_fwd_with_smoothquant, + rmsnorm2d_fwd_with_dynamicquant, + rmsnorm2d_fwd_with_add_smoothquant, + rmsnorm2d_fwd_with_add_dynamicquant, ) @@ -29,25 +34,66 @@ def torch_rmsnorm(x, g, out_dtype=torch.float16, epsilon=1e-6): return rms_norm -def run_torch(x, weight, eps, residual=None): +def run_torch(input, weight, eps, residual=None, x_scale=None, y_scale_dtype=None): if residual is None: residual_out = None - output = torch_rmsnorm(x, weight, x.dtype, eps) + output = torch_rmsnorm(input, weight, input.dtype, eps) else: - residual_out = x + residual + residual_out = input + residual output = torch_rmsnorm(residual_out, weight, residual_out.dtype, eps) - return output, residual_out - - -def run_triton(x, weight, eps, residual=None): - if residual is None: - residual_out = None - output = rms_norm(x, weight, eps) + if y_scale_dtype is None: + y_scale = None + output_q = output + else: + output_q, y_scale = aiter.pertoken_quant(output, x_scale=x_scale) + return output_q, residual_out, y_scale, output + + +def run_triton(input, weight, eps, residual=None, x_scale=None, y_scale_dtype=None): + # out_before_quant = None + if y_scale_dtype is None: + y_scale = None + if residual is None: + residual_out = None + output = rms_norm(input, weight, eps) + else: + residual_out = torch.empty_like(input) + output = torch.empty_like(input) + output = rmsnorm2d_fwd_with_add( + output, input, residual, residual_out, weight, eps + ) + elif x_scale is None: + y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") + output = torch.empty(input.shape, dtype=torch.int8, device="cuda") + if residual is None: + residual_out = None + rmsnorm2d_fwd_with_dynamicquant(output, input, y_scale, weight, eps) + elif residual is not None: + residual_out = torch.empty_like(input) + rmsnorm2d_fwd_with_add_dynamicquant( + output, input, residual, residual_out, y_scale, weight, eps + ) else: - residual_out = torch.empty_like(x) - output = torch.empty_like(x) - output = rmsnorm2d_fwd_with_add(output, x, residual, residual_out, weight, eps) - return output, residual_out + y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") + output = torch.empty(input.shape, dtype=torch.int8, device="cuda") + if residual is None: + residual_out = None + rmsnorm2d_fwd_with_smoothquant(output, input, x_scale, y_scale, weight, eps) + else: + residual_out = torch.empty_like(input) + # out_before_quant = torch.empty_like(input) + rmsnorm2d_fwd_with_add_smoothquant( + output, + input, + residual, + residual_out, + x_scale, + y_scale, + weight, + eps, + # out_before_quant=out_before_quant, + ) + return output, residual_out, y_scale # , out_before_quant def get_vals(): @@ -180,3 +226,179 @@ def test_fused_add_rmsnorm(M, N, 
in_dtype_str): triton.testing.assert_close(res_triton, res_torch, atol=atol, rtol=rtol) triton.testing.assert_close(dx_triton, dx_torch, rtol=rtol, atol=atol) triton.testing.assert_close(dg_triton, dg_torch, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], +) +def test_rmsnorm_smoothquant(M, N, in_dtype_str, scale_dtype_str): + arg_to_torch_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + in_dtype = arg_to_torch_dtype[in_dtype_str] + scale_dtype = arg_to_torch_dtype[scale_dtype_str] + + torch.manual_seed(0) + + x = torch.randn(M, N, device="cuda", dtype=in_dtype) + weight = torch.randn(N, device="cuda", dtype=in_dtype) + x_scale = torch.randn(N, device="cuda", dtype=scale_dtype) + + (y_torch, _, yscale_torch, *_) = run_torch( + x, weight, 1e-5, x_scale=x_scale, y_scale_dtype=scale_dtype + ) + (y_triton, _, yscale_triton, *_) = run_triton( + x, weight, 1e-5, x_scale=x_scale, y_scale_dtype=scale_dtype + ) + + triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) + triton.testing.assert_close(yscale_triton, yscale_torch, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], +) +def test_rmsnorm_dynamicquant(M, N, in_dtype_str, scale_dtype_str): + arg_to_torch_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + in_dtype = arg_to_torch_dtype[in_dtype_str] + scale_dtype = arg_to_torch_dtype[scale_dtype_str] + + torch.manual_seed(0) + + x = torch.randn(M, N, device="cuda", dtype=in_dtype) + weight = torch.randn(N, device="cuda", dtype=in_dtype) + + (y_torch, _, yscale_torch, *_) = run_torch( + x, weight, 1e-5, y_scale_dtype=scale_dtype + ) + (y_triton, _, yscale_triton, *_) = run_triton( + x, weight, 1e-5, y_scale_dtype=scale_dtype + ) + + triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) + triton.testing.assert_close(yscale_triton, yscale_torch, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], +) +def test_rmsnorm_fused_add_smoothquant(M, N, in_dtype_str, scale_dtype_str): + arg_to_torch_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + in_dtype = arg_to_torch_dtype[in_dtype_str] + scale_dtype = arg_to_torch_dtype[scale_dtype_str] + + torch.manual_seed(0) + + x = torch.randn(M, N, device="cuda", dtype=in_dtype) + weight = torch.randn(N, device="cuda", dtype=in_dtype) + res = torch.randn(M, N, device="cuda", dtype=in_dtype) + x_scale = torch.randn(N, device="cuda", dtype=scale_dtype) + + (y_torch, res_torch, yscale_torch, *_) = run_torch( + x, weight, 1e-5, residual=res, x_scale=x_scale, y_scale_dtype=scale_dtype + ) + (y_triton, res_triton, yscale_triton, *_) = run_triton( + x, weight, 1e-5, residual=res, x_scale=x_scale, y_scale_dtype=scale_dtype + ) + + triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) + triton.testing.assert_close(res_triton, res_torch, atol=1e-3, rtol=1e-3) + triton.testing.assert_close(yscale_triton, yscale_torch, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"]) 
+@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], +) +def test_rmsnorm_fused_add_dynamicquant(M, N, in_dtype_str, scale_dtype_str): + arg_to_torch_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + in_dtype = arg_to_torch_dtype[in_dtype_str] + scale_dtype = arg_to_torch_dtype[scale_dtype_str] + + torch.manual_seed(0) + + x = torch.randn(M, N, device="cuda", dtype=in_dtype) + weight = torch.randn(N, device="cuda", dtype=in_dtype) + res = torch.randn(M, N, device="cuda", dtype=in_dtype) + + (y_torch, res_torch, yscale_torch, *_) = run_torch( + x, weight, 1e-5, residual=res, y_scale_dtype=scale_dtype + ) + (y_triton, res_triton, yscale_triton, *_) = run_triton( + x, weight, 1e-5, residual=res, y_scale_dtype=scale_dtype + ) + + triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) + triton.testing.assert_close(res_triton, res_torch, atol=1e-3, rtol=1e-3) + triton.testing.assert_close(yscale_triton, yscale_torch, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("B", [1, 4, 8]) +@pytest.mark.parametrize("T", [128, 512, 2048]) +@pytest.mark.parametrize("D", [64, 4096]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_rms_norm_dynamic_per_token_fp8_quant( + B: int, T: int, D: int, dtype: torch.dtype +) -> None: + B_T = B * T + # Use integers to ensure consistent results across layouts, + # avoiding discrepancies in floating-point reductions with varying data layouts + x = torch.floor(torch.distributions.Uniform(-3, 3).sample((B_T, D))).to( + dtype=dtype, device="cuda" + ) + w = torch.floor(torch.distributions.Uniform(-3, 3).sample((D,))).to( + dtype=dtype, device="cuda" + ) + + EPS = 1e-6 + quant_dtype = torch.float8_e4m3fnuz + + xq_fused_triton = torch.empty(x.shape, dtype=quant_dtype, device="cuda") + x_scale_fused = torch.empty(x.shape[0], 1, dtype=torch.float32, device="cuda") + + x_normed = rmsnorm2d_fwd_with_dynamicquant( + xq_fused_triton, x, x_scale_fused, w, EPS, dump_rms_norm=True + ) + + ref_x_normed = torch_rmsnorm(x, w, dtype, EPS) + ref_xq, ref_x_scale = aiter.pertoken_quant(ref_x_normed, quant_dtype=quant_dtype) + + xq_dequant = xq_fused_triton.to(torch.float32) * x_scale_fused + xq_dequant = xq_dequant.to(dtype) + ref_xq_dequant = ref_xq.to(torch.float32) * ref_x_scale + ref_xq_dequant = xq_dequant.to(dtype) + + if dtype == torch.float32: + atol = 1e-5 + rtol = 1e-5 + else: + atol = 1e-2 + rtol = 1e-2 + + torch.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) + torch.testing.assert_close(x_normed, ref_x_normed, atol=atol, rtol=rtol) diff --git a/op_tests/triton_tests/test_rmsnorm_quant.py b/op_tests/triton_tests/test_rmsnorm_quant.py deleted file mode 100644 index 67d4a9e987..0000000000 --- a/op_tests/triton_tests/test_rmsnorm_quant.py +++ /dev/null @@ -1,282 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -import torch -import pytest -import triton -import aiter -from aiter.ops.triton.rmsnorm import ( - rms_norm, - rmsnorm2d_fwd_with_add, - rmsnorm2d_fwd_with_smoothquant, - rmsnorm2d_fwd_with_dynamicquant, - rmsnorm2d_fwd_with_add_smoothquant, - rmsnorm2d_fwd_with_add_dynamicquant, -) - - -def torch_rmsnorm(x, g, out_dtype=torch.float16, epsilon=1e-6): - M, N = x.shape - # cast to float32 as the triton kernel - x_f32 = x.float() - g_f32 = g.float() - rms = torch.sqrt(torch.sum(x_f32 * x_f32, dim=-1) * 1 / N) - rsigma = 1.0 / rms - rms_norm_f32 = x_f32 * rsigma.unsqueeze(1) * g_f32 - rms_norm = rms_norm_f32.to(out_dtype) - return rms_norm - - -def run_torch(input, weight, eps, residual=None, x_scale=None, y_scale_dtype=None): - if residual is None: - residual_out = None - output = torch_rmsnorm(input, weight, input.dtype, eps) - else: - residual_out = input + residual - output = torch_rmsnorm(residual_out, weight, residual_out.dtype, eps) - if y_scale_dtype is None: - y_scale = None - output_q = output - else: - output_q, y_scale = aiter.pertoken_quant(output, x_scale=x_scale) - return output_q, residual_out, y_scale, output - - -def run_triton(input, weight, eps, residual=None, x_scale=None, y_scale_dtype=None): - # out_before_quant = None - if y_scale_dtype is None: - y_scale = None - if residual is None: - residual_out = None - output = rms_norm(input, weight, eps) - else: - residual_out = torch.empty_like(input) - output = torch.empty_like(input) - rmsnorm2d_fwd_with_add(output, input, residual, residual_out, weight, eps) - elif x_scale is None: - y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") - output = torch.empty(input.shape, dtype=torch.int8, device="cuda") - if residual is None: - residual_out = None - rmsnorm2d_fwd_with_dynamicquant(output, input, y_scale, weight, eps) - elif residual is not None: - residual_out = torch.empty_like(input) - rmsnorm2d_fwd_with_add_dynamicquant( - output, input, residual, residual_out, y_scale, weight, eps - ) - else: - y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") - output = torch.empty(input.shape, dtype=torch.int8, device="cuda") - if residual is None: - residual_out = None - rmsnorm2d_fwd_with_smoothquant(output, input, x_scale, y_scale, weight, eps) - else: - residual_out = torch.empty_like(input) - # out_before_quant = torch.empty_like(input) - rmsnorm2d_fwd_with_add_smoothquant( - output, - input, - residual, - residual_out, - x_scale, - y_scale, - weight, - eps, - # out_before_quant=out_before_quant, - ) - return output, residual_out, y_scale # , out_before_quant - - -def get_vals(): - - vals = [ - (1, 4), - (2, 10), - (8192, 4096), - (4096, 8192), - (8000, 8000), - (1, 31744), - (3, 65536), - (1, 131072), - (873, 1245), - (23153, 45), - (89999, 234), - ] - - return vals - - -@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) -@pytest.mark.parametrize( - "M, N", - [(shape) for shape in get_vals()], -) -def test_rmsnorm_smoothquant(M, N, in_dtype_str, scale_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] - - torch.manual_seed(0) - - x = torch.randn(M, N, device="cuda", dtype=in_dtype) - weight = torch.randn(N, device="cuda", dtype=in_dtype) - x_scale = torch.randn(N, device="cuda", dtype=scale_dtype) - - (y_torch, _, yscale_torch, *_) = run_torch( - 
x, weight, 1e-5, x_scale=x_scale, y_scale_dtype=scale_dtype - ) - (y_triton, _, yscale_triton, *_) = run_triton( - x, weight, 1e-5, x_scale=x_scale, y_scale_dtype=scale_dtype - ) - - triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) - triton.testing.assert_close(yscale_triton, yscale_torch, atol=1e-3, rtol=1e-3) - - -@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) -@pytest.mark.parametrize( - "M, N", - [(shape) for shape in get_vals()], -) -def test_rmsnorm_dynamicquant(M, N, in_dtype_str, scale_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] - - torch.manual_seed(0) - - x = torch.randn(M, N, device="cuda", dtype=in_dtype) - weight = torch.randn(N, device="cuda", dtype=in_dtype) - - (y_torch, _, yscale_torch, *_) = run_torch( - x, weight, 1e-5, y_scale_dtype=scale_dtype - ) - (y_triton, _, yscale_triton, *_) = run_triton( - x, weight, 1e-5, y_scale_dtype=scale_dtype - ) - - triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) - triton.testing.assert_close(yscale_triton, yscale_torch, atol=1e-3, rtol=1e-3) - - -@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) -@pytest.mark.parametrize( - "M, N", - [(shape) for shape in get_vals()], -) -def test_rmsnorm_fused_add_smoothquant(M, N, in_dtype_str, scale_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] - - torch.manual_seed(0) - - x = torch.randn(M, N, device="cuda", dtype=in_dtype) - weight = torch.randn(N, device="cuda", dtype=in_dtype) - res = torch.randn(M, N, device="cuda", dtype=in_dtype) - x_scale = torch.randn(N, device="cuda", dtype=scale_dtype) - - (y_torch, res_torch, yscale_torch, *_) = run_torch( - x, weight, 1e-5, residual=res, x_scale=x_scale, y_scale_dtype=scale_dtype - ) - (y_triton, res_triton, yscale_triton, *_) = run_triton( - x, weight, 1e-5, residual=res, x_scale=x_scale, y_scale_dtype=scale_dtype - ) - - triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) - triton.testing.assert_close(res_triton, res_torch, atol=1e-3, rtol=1e-3) - triton.testing.assert_close(yscale_triton, yscale_torch, atol=1e-3, rtol=1e-3) - - -@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) -@pytest.mark.parametrize( - "M, N", - [(shape) for shape in get_vals()], -) -def test_rmsnorm_fused_add_dynamicquant(M, N, in_dtype_str, scale_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] - - torch.manual_seed(0) - - x = torch.randn(M, N, device="cuda", dtype=in_dtype) - weight = torch.randn(N, device="cuda", dtype=in_dtype) - res = torch.randn(M, N, device="cuda", dtype=in_dtype) - - (y_torch, res_torch, yscale_torch, *_) = run_torch( - x, weight, 1e-5, residual=res, y_scale_dtype=scale_dtype - ) - (y_triton, res_triton, yscale_triton, *_) = run_triton( - x, weight, 1e-5, residual=res, y_scale_dtype=scale_dtype - ) - - triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) - triton.testing.assert_close(res_triton, 
res_torch, atol=1e-3, rtol=1e-3) - triton.testing.assert_close(yscale_triton, yscale_torch, atol=1e-3, rtol=1e-3) - - -@pytest.mark.parametrize("B", [1, 4, 8]) -@pytest.mark.parametrize("T", [128, 512, 2048]) -@pytest.mark.parametrize("D", [64, 4096]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) -def test_rms_norm_dynamic_per_token_fp8_quant( - B: int, T: int, D: int, dtype: torch.dtype -) -> None: - B_T = B * T - # Use integers to ensure consistent results across layouts, - # avoiding discrepancies in floating-point reductions with varying data layouts - x = torch.floor(torch.distributions.Uniform(-3, 3).sample((B_T, D))).to( - dtype=dtype, device="cuda" - ) - w = torch.floor(torch.distributions.Uniform(-3, 3).sample((D,))).to( - dtype=dtype, device="cuda" - ) - - EPS = 1e-6 - quant_dtype = torch.float8_e4m3fnuz - - xq_fused_triton = torch.empty(x.shape, dtype=quant_dtype, device="cuda") - x_scale_fused = torch.empty(x.shape[0], 1, dtype=torch.float32, device="cuda") - - x_normed = rmsnorm2d_fwd_with_dynamicquant( - xq_fused_triton, x, x_scale_fused, w, EPS, dump_rms_norm=True - ) - - ref_x_normed = torch_rmsnorm(x, w, dtype, EPS) - ref_xq, ref_x_scale = aiter.pertoken_quant(ref_x_normed, quant_dtype=quant_dtype) - - xq_dequant = xq_fused_triton.to(torch.float32) * x_scale_fused - xq_dequant = xq_dequant.to(dtype) - ref_xq_dequant = ref_xq.to(torch.float32) * ref_x_scale - ref_xq_dequant = xq_dequant.to(dtype) - - if dtype == torch.float32: - atol = 1e-5 - rtol = 1e-5 - else: - atol = 1e-2 - rtol = 1e-2 - - torch.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) - torch.testing.assert_close(x_normed, ref_x_normed, atol=atol, rtol=rtol) From 2bc65118481dd7c0543dd2d6eec4c157b3c7fbe8 Mon Sep 17 00:00:00 2001 From: Lucas Santos Date: Tue, 1 Jul 2025 19:24:04 +0000 Subject: [PATCH 2/3] Refactor LayerNorm Triton tests --- op_tests/triton_tests/test_layernorm.py | 313 ++++++++++++++++-- op_tests/triton_tests/test_layernorm_quant.py | 287 ---------------- 2 files changed, 277 insertions(+), 323 deletions(-) delete mode 100644 op_tests/triton_tests/test_layernorm_quant.py diff --git a/op_tests/triton_tests/test_layernorm.py b/op_tests/triton_tests/test_layernorm.py index 0e44ece25e..5b749405cf 100644 --- a/op_tests/triton_tests/test_layernorm.py +++ b/op_tests/triton_tests/test_layernorm.py @@ -5,11 +5,20 @@ import torch import torch.nn.functional as F import pytest -from aiter.ops.triton.norm import layer_norm -from aiter.ops.triton.norm import layernorm2d_fwd_with_add +import aiter +from aiter.ops.triton.norm import ( + layer_norm, + layernorm2d_fwd_with_add, + layernorm2d_fwd_with_dynamicquant, + layernorm2d_fwd_with_smoothquant, + layernorm2d_fwd_with_add_dynamicquant, + layernorm2d_fwd_with_add_smoothquant, +) -def run_torch(input, weight, bias, eps, residual=None, x_bias=None): +def run_torch( + input, weight, bias, eps, residual=None, x_scale=None, y_scale_dtype=None +): if residual is None: residual_out = None output = F.layer_norm( @@ -28,39 +37,99 @@ def run_torch(input, weight, bias, eps, residual=None, x_bias=None): bias=bias, eps=eps, ) - return output, residual_out + if y_scale_dtype is None: + y_scale = None + else: + output, y_scale = aiter.pertoken_quant( + output, x_scale=x_scale, quant_dtype=torch.int8 + ) + return output, residual_out, y_scale -def run_triton(input, weight, bias, eps, residual=None, x_bias=None): - if residual is None: - residual_out = None - output = layer_norm(input, weight, bias, eps, x_bias) +def run_triton( + 
input, + weight, + bias, + eps, + residual=None, + x_bias=None, + x_scale=None, + y_scale_dtype=None, +): + if y_scale_dtype is None: + y_scale = None + if residual is None: + residual_out = None + output = layer_norm(input, weight, bias, eps, x_bias) + else: + residual_out = torch.empty_like(input) + output = torch.empty_like(input) + layernorm2d_fwd_with_add( + output, input, residual, residual_out, weight, bias, eps, x_bias + ) + elif x_scale is None: + y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") + output = torch.empty(input.shape, dtype=torch.int8, device="cuda") + if residual is None: + residual_out = None + layernorm2d_fwd_with_dynamicquant(output, input, y_scale, weight, bias, eps) + elif residual is not None: + residual_out = torch.empty_like(input) + layernorm2d_fwd_with_add_dynamicquant( + output, input, residual, residual_out, y_scale, weight, bias, eps + ) else: - residual_out = torch.empty_like(input) - output = torch.empty_like(input) - layernorm2d_fwd_with_add( - output, input, residual, residual_out, weight, bias, eps, x_bias - ) - return output, residual_out + y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") + output = torch.empty(input.shape, dtype=torch.int8, device="cuda") + if residual is None: + residual_out = None + layernorm2d_fwd_with_smoothquant( + output, input, x_scale, y_scale, weight, bias, eps + ) + elif residual is not None: + residual_out = torch.empty_like(input) + layernorm2d_fwd_with_add_smoothquant( + output, + input, + residual, + residual_out, + x_scale, + y_scale, + weight, + bias, + eps, + ) + return output, residual_out, y_scale -# pytest -@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize( - "M, N", - [ - (1823, 781), + +# TODO: Enable the commented shapes once the bug +# discussed in this issue is solved: +# https://github.com/ROCm/triton-internal/issues/843 +def get_vals(): + + vals = [ + # (1823, 781), (2, 128), (1, 4), (128, 2), (1, 128), - (8192, 8192), - (4096, 8192), + # (8192, 8192), + # (4096, 8192), (359, 1), (1, 359), (1, 131072), (1, 89999), - ], + ] + + return vals + + +# pytest +@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], ) def test_layernorm(M, N, dtype_str, eps=1e-5): arg_to_torch_dtype = { @@ -92,19 +161,7 @@ def test_layernorm(M, N, dtype_str, eps=1e-5): @pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) @pytest.mark.parametrize( "M, N", - [ - (1823, 781), - (2, 128), - (1, 4), - (128, 2), - (1, 128), - (8192, 8192), - (4096, 8192), - (359, 1), - (1, 359), - (1, 131072), - (1, 89999), - ], + [(shape) for shape in get_vals()], ) def test_fused_add_layernorm(M, N, dtype_str, eps=1e-5): arg_to_torch_dtype = { @@ -132,3 +189,187 @@ def test_fused_add_layernorm(M, N, dtype_str, eps=1e-5): triton.testing.assert_close(y_triton, y_torch, atol=atol, rtol=rtol) triton.testing.assert_close(res_triton, res_torch, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], +) +def test_layernorm_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): + arg_to_torch_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + dtype = arg_to_torch_dtype[dtype_str] + scale_dtype = arg_to_torch_dtype[scale_dtype_str] + torch.manual_seed(0) + + x = 
torch.randn(M, N, device="cuda", dtype=dtype) + w_shape = (N,) + b = torch.rand(w_shape, device="cuda", dtype=dtype) + w = torch.rand(w_shape, device="cuda", dtype=dtype) + x_scale = torch.rand(w_shape, device="cuda", dtype=scale_dtype) + + y_torch, _, y_scale_torch = run_torch( + x, w, b, eps, x_scale=x_scale, y_scale_dtype=scale_dtype + ) + y_triton, _, y_scale_triton = run_triton( + x, w, b, eps, x_scale=x_scale, y_scale_dtype=scale_dtype + ) + + xq_dequant = y_triton.to(torch.int32) * y_scale_triton + xq_dequant = xq_dequant.to(dtype) + ref_xq_dequant = y_torch.to(torch.int32) * y_scale_torch + ref_xq_dequant = xq_dequant.to(dtype) + + if dtype == torch.float32: + atol = 1e-5 + rtol = 1e-5 + else: + atol = 1e-2 + rtol = 1e-2 + + triton.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) + triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) + triton.testing.assert_close(y_scale_triton, y_scale_torch, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], +) +def test_layernorm_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): + arg_to_torch_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + dtype = arg_to_torch_dtype[dtype_str] + scale_dtype = arg_to_torch_dtype[scale_dtype_str] + torch.manual_seed(0) + + x = torch.randn(M, N, device="cuda", dtype=dtype) + w_shape = (N,) + b = torch.rand(w_shape, device="cuda", dtype=dtype) + w = torch.rand(w_shape, device="cuda", dtype=dtype) + + # forward pass + y_torch, _, y_scale_torch = run_torch(x, w, b, eps, y_scale_dtype=scale_dtype) + y_triton, _, y_scale_triton = run_triton(x, w, b, eps, y_scale_dtype=scale_dtype) + + xq_dequant = y_triton.to(torch.int32) * y_scale_triton + xq_dequant = xq_dequant.to(dtype) + ref_xq_dequant = y_torch.to(torch.int32) * y_scale_torch + ref_xq_dequant = xq_dequant.to(dtype) + + if dtype == torch.float32: + atol = 1e-5 + rtol = 1e-5 + else: + atol = 1e-2 + rtol = 1e-2 + + triton.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) + triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) + triton.testing.assert_close(y_scale_triton, y_scale_torch, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], +) +def test_layernorm_fused_add_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): + arg_to_torch_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + dtype = arg_to_torch_dtype[dtype_str] + scale_dtype = arg_to_torch_dtype[scale_dtype_str] + torch.manual_seed(0) + + x = torch.randn(M, N, device="cuda", dtype=dtype) + res = torch.randn(M, N, device="cuda", dtype=dtype) + w_shape = (N,) + b = torch.rand(w_shape, device="cuda", dtype=dtype) + w = torch.rand(w_shape, device="cuda", dtype=dtype) + x_scale = torch.rand(w_shape, device="cuda", dtype=scale_dtype) + + y_torch, res_torch, y_scale_torch = run_torch( + x, w, b, eps, residual=res, x_scale=x_scale, y_scale_dtype=scale_dtype + ) + y_triton, res_triton, y_scale_triton = run_triton( + x, w, b, eps, residual=res, x_scale=x_scale, y_scale_dtype=scale_dtype + ) + + xq_dequant = y_triton.to(torch.int32) * y_scale_triton + xq_dequant = xq_dequant.to(dtype) + ref_xq_dequant = y_torch.to(torch.int32) * 
y_scale_torch + ref_xq_dequant = xq_dequant.to(dtype) + + if dtype == torch.float32: + atol = 1e-5 + rtol = 1e-5 + else: + atol = 1e-2 + rtol = 1e-2 + + triton.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) + triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) + triton.testing.assert_close(res_triton, res_torch, atol=atol, rtol=rtol) + triton.testing.assert_close(y_scale_triton, y_scale_torch, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) +@pytest.mark.parametrize( + "M, N", + [(shape) for shape in get_vals()], +) +def test_layernorm_fused_add_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): + arg_to_torch_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + dtype = arg_to_torch_dtype[dtype_str] + scale_dtype = arg_to_torch_dtype[scale_dtype_str] + torch.manual_seed(0) + + x = torch.randn(M, N, device="cuda", dtype=dtype) + res = torch.randn(M, N, device="cuda", dtype=dtype) + w_shape = (N,) + b = torch.rand(w_shape, device="cuda", dtype=dtype) + w = torch.rand(w_shape, device="cuda", dtype=dtype) + + # forward pass + y_torch, res_torch, y_scale_torch = run_torch( + x, w, b, eps, residual=res, y_scale_dtype=scale_dtype + ) + y_triton, res_triton, y_scale_triton = run_triton( + x, w, b, eps, residual=res, y_scale_dtype=scale_dtype + ) + + xq_dequant = y_triton.to(torch.int32) * y_scale_triton + xq_dequant = xq_dequant.to(dtype) + ref_xq_dequant = y_torch.to(torch.int32) * y_scale_torch + ref_xq_dequant = xq_dequant.to(dtype) + + if dtype == torch.float32: + atol = 1e-5 + rtol = 1e-5 + else: + atol = 1e-2 + rtol = 1e-2 + + triton.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) + triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) + triton.testing.assert_close(res_triton, res_torch, atol=atol, rtol=rtol) + triton.testing.assert_close(y_scale_triton, y_scale_torch, atol=1e-3, rtol=1e-3) diff --git a/op_tests/triton_tests/test_layernorm_quant.py b/op_tests/triton_tests/test_layernorm_quant.py deleted file mode 100644 index 7c222f9bd8..0000000000 --- a/op_tests/triton_tests/test_layernorm_quant.py +++ /dev/null @@ -1,287 +0,0 @@ -import triton -import torch -import torch.nn.functional as F -import pytest -import aiter -from aiter.ops.triton.norm import ( - layernorm2d_fwd_with_dynamicquant, - layernorm2d_fwd_with_smoothquant, - layernorm2d_fwd_with_add_dynamicquant, - layernorm2d_fwd_with_add_smoothquant, -) - - -def run_torch( - input, weight, bias, eps, residual=None, x_scale=None, y_scale_dtype=None -): - if residual is None: - residual_out = None - output = F.layer_norm( - input=input, - normalized_shape=(input.shape[-1],), - weight=weight, - bias=bias, - eps=eps, - ) - else: - residual_out = input + residual - output = F.layer_norm( - input=residual_out, - normalized_shape=(input.shape[-1],), - weight=weight, - bias=bias, - eps=eps, - ) - if y_scale_dtype is None: - y_scale = None - else: - output, y_scale = aiter.pertoken_quant( - output, x_scale=x_scale, quant_dtype=torch.int8 - ) - return output, residual_out, y_scale - - -def run_triton( - input, weight, bias, eps, residual=None, x_scale=None, y_scale_dtype=None -): - if x_scale is None: - y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") - output = torch.empty(input.shape, dtype=torch.int8, device="cuda") - if residual is None: - residual_out = None - layernorm2d_fwd_with_dynamicquant(output, 
input, y_scale, weight, bias, eps) - elif residual is not None: - residual_out = torch.empty_like(input) - layernorm2d_fwd_with_add_dynamicquant( - output, input, residual, residual_out, y_scale, weight, bias, eps - ) - else: - y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype, device="cuda") - output = torch.empty(input.shape, dtype=torch.int8, device="cuda") - if residual is None: - residual_out = None - layernorm2d_fwd_with_smoothquant( - output, input, x_scale, y_scale, weight, bias, eps - ) - elif residual is not None: - residual_out = torch.empty_like(input) - layernorm2d_fwd_with_add_smoothquant( - output, - input, - residual, - residual_out, - x_scale, - y_scale, - weight, - bias, - eps, - ) - - return output, residual_out, y_scale - - -# TODO: Enable the commented shapes once the bug -# discussed in this issue is solved: -# https://github.com/ROCm/triton-internal/issues/843 -def get_vals(): - - vals = [ - # (1823, 781), - (2, 128), - (1, 4), - (128, 2), - (1, 128), - # (8192, 8192), - # (4096, 8192), - (359, 1), - (1, 359), - (1, 131072), - (1, 89999), - ] - - return vals - - -# pytest -@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) -@pytest.mark.parametrize( - "M, N", - [(shape) for shape in get_vals()], -) -def test_layernorm_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) - - x = torch.randn(M, N, device="cuda", dtype=dtype) - w_shape = (N,) - b = torch.rand(w_shape, device="cuda", dtype=dtype) - w = torch.rand(w_shape, device="cuda", dtype=dtype) - x_scale = torch.rand(w_shape, device="cuda", dtype=scale_dtype) - - y_torch, _, y_scale_torch = run_torch( - x, w, b, eps, x_scale=x_scale, y_scale_dtype=scale_dtype - ) - y_triton, _, y_scale_triton = run_triton( - x, w, b, eps, x_scale=x_scale, y_scale_dtype=scale_dtype - ) - - xq_dequant = y_triton.to(torch.int32) * y_scale_triton - xq_dequant = xq_dequant.to(dtype) - ref_xq_dequant = y_torch.to(torch.int32) * y_scale_torch - ref_xq_dequant = xq_dequant.to(dtype) - - if dtype == torch.float32: - atol = 1e-5 - rtol = 1e-5 - else: - atol = 1e-2 - rtol = 1e-2 - - triton.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) - triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) - triton.testing.assert_close(y_scale_triton, y_scale_torch, atol=1e-3, rtol=1e-3) - - -@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) -@pytest.mark.parametrize( - "M, N", - [(shape) for shape in get_vals()], -) -def test_layernorm_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) - - x = torch.randn(M, N, device="cuda", dtype=dtype) - w_shape = (N,) - b = torch.rand(w_shape, device="cuda", dtype=dtype) - w = torch.rand(w_shape, device="cuda", dtype=dtype) - - # forward pass - y_torch, _, y_scale_torch = run_torch(x, w, b, eps, y_scale_dtype=scale_dtype) - y_triton, _, y_scale_triton = run_triton(x, w, b, eps, y_scale_dtype=scale_dtype) - - xq_dequant = y_triton.to(torch.int32) * y_scale_triton - xq_dequant = xq_dequant.to(dtype) - 
ref_xq_dequant = y_torch.to(torch.int32) * y_scale_torch - ref_xq_dequant = xq_dequant.to(dtype) - - if dtype == torch.float32: - atol = 1e-5 - rtol = 1e-5 - else: - atol = 1e-2 - rtol = 1e-2 - - triton.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) - triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) - triton.testing.assert_close(y_scale_triton, y_scale_torch, atol=1e-3, rtol=1e-3) - - -@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) -@pytest.mark.parametrize( - "M, N", - [(shape) for shape in get_vals()], -) -def test_layernorm_fused_add_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) - - x = torch.randn(M, N, device="cuda", dtype=dtype) - res = torch.randn(M, N, device="cuda", dtype=dtype) - w_shape = (N,) - b = torch.rand(w_shape, device="cuda", dtype=dtype) - w = torch.rand(w_shape, device="cuda", dtype=dtype) - x_scale = torch.rand(w_shape, device="cuda", dtype=scale_dtype) - - y_torch, res_torch, y_scale_torch = run_torch( - x, w, b, eps, residual=res, x_scale=x_scale, y_scale_dtype=scale_dtype - ) - y_triton, res_triton, y_scale_triton = run_triton( - x, w, b, eps, residual=res, x_scale=x_scale, y_scale_dtype=scale_dtype - ) - - xq_dequant = y_triton.to(torch.int32) * y_scale_triton - xq_dequant = xq_dequant.to(dtype) - ref_xq_dequant = y_torch.to(torch.int32) * y_scale_torch - ref_xq_dequant = xq_dequant.to(dtype) - - if dtype == torch.float32: - atol = 1e-5 - rtol = 1e-5 - else: - atol = 1e-2 - rtol = 1e-2 - - triton.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) - triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) - triton.testing.assert_close(res_triton, res_torch, atol=atol, rtol=rtol) - triton.testing.assert_close(y_scale_triton, y_scale_torch, atol=1e-3, rtol=1e-3) - - -@pytest.mark.parametrize("dtype_str", ["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("scale_dtype_str", ["fp32"]) -@pytest.mark.parametrize( - "M, N", - [(shape) for shape in get_vals()], -) -def test_layernorm_fused_add_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) - - x = torch.randn(M, N, device="cuda", dtype=dtype) - res = torch.randn(M, N, device="cuda", dtype=dtype) - w_shape = (N,) - b = torch.rand(w_shape, device="cuda", dtype=dtype) - w = torch.rand(w_shape, device="cuda", dtype=dtype) - - # forward pass - y_torch, res_torch, y_scale_torch = run_torch( - x, w, b, eps, residual=res, y_scale_dtype=scale_dtype - ) - y_triton, res_triton, y_scale_triton = run_triton( - x, w, b, eps, residual=res, y_scale_dtype=scale_dtype - ) - - xq_dequant = y_triton.to(torch.int32) * y_scale_triton - xq_dequant = xq_dequant.to(dtype) - ref_xq_dequant = y_torch.to(torch.int32) * y_scale_torch - ref_xq_dequant = xq_dequant.to(dtype) - - if dtype == torch.float32: - atol = 1e-5 - rtol = 1e-5 - else: - atol = 1e-2 - rtol = 1e-2 - - triton.testing.assert_close(xq_dequant, ref_xq_dequant, atol=atol, rtol=rtol) - triton.testing.assert_close(y_triton, y_torch, atol=1, rtol=0) - triton.testing.assert_close(res_triton, res_torch, 
atol=atol, rtol=rtol) - triton.testing.assert_close(y_scale_triton, y_scale_torch, atol=1e-3, rtol=1e-3) From 61decc6f79145861388c9751cdf5610d6b996cef Mon Sep 17 00:00:00 2001 From: Lucas Santos Date: Wed, 2 Jul 2025 18:50:13 +0000 Subject: [PATCH 3/3] Updated to use commom dictionary str_to_torch_dtype --- op_tests/triton_tests/test_layernorm.py | 51 +++++----------------- op_tests/triton_tests/test_rmsnorm.py | 57 ++++++++----------------- 2 files changed, 28 insertions(+), 80 deletions(-) diff --git a/op_tests/triton_tests/test_layernorm.py b/op_tests/triton_tests/test_layernorm.py index 5b749405cf..0ddd3e8d64 100644 --- a/op_tests/triton_tests/test_layernorm.py +++ b/op_tests/triton_tests/test_layernorm.py @@ -6,6 +6,7 @@ import torch.nn.functional as F import pytest import aiter +from aiter.ops.triton.utils.types import str_to_torch_dtype from aiter.ops.triton.norm import ( layer_norm, layernorm2d_fwd_with_add, @@ -132,12 +133,7 @@ def get_vals(): [(shape) for shape in get_vals()], ) def test_layernorm(M, N, dtype_str, eps=1e-5): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] + dtype = str_to_torch_dtype[dtype_str] torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) w_shape = (N,) @@ -164,12 +160,7 @@ def test_layernorm(M, N, dtype_str, eps=1e-5): [(shape) for shape in get_vals()], ) def test_fused_add_layernorm(M, N, dtype_str, eps=1e-5): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] + dtype = str_to_torch_dtype[dtype_str] torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) res = torch.randn(M, N, device="cuda", dtype=dtype) @@ -198,13 +189,8 @@ def test_fused_add_layernorm(M, N, dtype_str, eps=1e-5): [(shape) for shape in get_vals()], ) def test_layernorm_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] + dtype = str_to_torch_dtype[dtype_str] + scale_dtype = str_to_torch_dtype[scale_dtype_str] torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) @@ -244,13 +230,8 @@ def test_layernorm_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): [(shape) for shape in get_vals()], ) def test_layernorm_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] + dtype = str_to_torch_dtype[dtype_str] + scale_dtype = str_to_torch_dtype[scale_dtype_str] torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) @@ -286,13 +267,8 @@ def test_layernorm_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): [(shape) for shape in get_vals()], ) def test_layernorm_fused_add_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] + dtype = str_to_torch_dtype[dtype_str] + scale_dtype = str_to_torch_dtype[scale_dtype_str] torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) @@ -334,13 +310,8 @@ def 
test_layernorm_fused_add_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1 [(shape) for shape in get_vals()], ) def test_layernorm_fused_add_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - dtype = arg_to_torch_dtype[dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] + dtype = str_to_torch_dtype[dtype_str] + scale_dtype = str_to_torch_dtype[scale_dtype_str] torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) diff --git a/op_tests/triton_tests/test_rmsnorm.py b/op_tests/triton_tests/test_rmsnorm.py index 2c20ce4e61..3a563149b9 100644 --- a/op_tests/triton_tests/test_rmsnorm.py +++ b/op_tests/triton_tests/test_rmsnorm.py @@ -5,6 +5,7 @@ import torch import triton import aiter +from aiter.ops.triton.utils.types import str_to_torch_dtype from aiter.ops.triton.rmsnorm import ( rms_norm, rmsnorm2d_fwd_with_add, @@ -127,12 +128,8 @@ def get_vals(): [(shape) for shape in get_vals()], ) def test_rmsnorm(M, N, in_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] + + in_dtype = str_to_torch_dtype[in_dtype_str] out_dtype = in_dtype torch.manual_seed(0) @@ -179,12 +176,8 @@ def test_rmsnorm(M, N, in_dtype_str): [(shape) for shape in get_vals()], ) def test_fused_add_rmsnorm(M, N, in_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] + + in_dtype = str_to_torch_dtype[in_dtype_str] out_dtype = in_dtype torch.manual_seed(0) @@ -235,13 +228,9 @@ def test_fused_add_rmsnorm(M, N, in_dtype_str): [(shape) for shape in get_vals()], ) def test_rmsnorm_smoothquant(M, N, in_dtype_str, scale_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] + + in_dtype = str_to_torch_dtype[in_dtype_str] + scale_dtype = str_to_torch_dtype[scale_dtype_str] torch.manual_seed(0) @@ -267,13 +256,9 @@ def test_rmsnorm_smoothquant(M, N, in_dtype_str, scale_dtype_str): [(shape) for shape in get_vals()], ) def test_rmsnorm_dynamicquant(M, N, in_dtype_str, scale_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] + + in_dtype = str_to_torch_dtype[in_dtype_str] + scale_dtype = str_to_torch_dtype[scale_dtype_str] torch.manual_seed(0) @@ -298,13 +283,9 @@ def test_rmsnorm_dynamicquant(M, N, in_dtype_str, scale_dtype_str): [(shape) for shape in get_vals()], ) def test_rmsnorm_fused_add_smoothquant(M, N, in_dtype_str, scale_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] + + in_dtype = str_to_torch_dtype[in_dtype_str] + scale_dtype = str_to_torch_dtype[scale_dtype_str] torch.manual_seed(0) @@ -332,13 +313,9 @@ def test_rmsnorm_fused_add_smoothquant(M, N, in_dtype_str, scale_dtype_str): [(shape) for shape in get_vals()], ) def test_rmsnorm_fused_add_dynamicquant(M, N, in_dtype_str, scale_dtype_str): - arg_to_torch_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": 
torch.float32, - } - in_dtype = arg_to_torch_dtype[in_dtype_str] - scale_dtype = arg_to_torch_dtype[scale_dtype_str] + + in_dtype = str_to_torch_dtype[in_dtype_str] + scale_dtype = str_to_torch_dtype[scale_dtype_str] torch.manual_seed(0)
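
For reference on the third commit's refactor: it replaces the per-test arg_to_torch_dtype dicts with a single import, aiter.ops.triton.utils.types.str_to_torch_dtype. That module is not part of this patch series, so its contents are not shown here. Below is only a minimal sketch of what the mapping must provide for these tests to pass, inferred from the inline dicts it replaces; the real module may define additional entries, and the name of the file is assumed.

    # Hypothetical sketch of the shared mapping, not the actual aiter module.
    # Maps the dtype strings used in the pytest parametrizations to torch dtypes.
    import torch

    str_to_torch_dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

With a shared mapping in place, each test resolves its dtypes with a plain lookup such as in_dtype = str_to_torch_dtype[in_dtype_str] instead of rebuilding the same dict locally, which is the simplification the third commit applies to both test_layernorm.py and test_rmsnorm.py.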