diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py index a04408b9cfd5..5b41390d9f37 100644 --- a/python/perf-kernels/rmsnorm.py +++ b/python/perf-kernels/rmsnorm.py @@ -46,35 +46,67 @@ def get_autotune_config(): @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) @triton.jit -def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, +def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps, BLOCK_SIZE: tl.constexpr): - row_start = tl.program_id(0) - row_step = tl.num_programs(0) - col_offsets = tl.arange(0, BLOCK_SIZE) + row_idx = tl.program_id(0) + + #Calculate squared mean by block + row_start_ptr = input_ptr + row_idx * input_row_stride + row_sum = 0.0 + n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 + #tl.device_print("n_cols_blks",n_cols_blks) + for b in tl.range(0, n_cols_blks): + col_offsets = b*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + row_block = tl.load(input_ptrs, cache_modifier=".cg") + row_block = row_block * row_block #square every value the block + row_sum += (tl.sum(row_block, axis=-1) / n_cols) #tl.sum across row + + col_offsets = n_cols_blks*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets mask = col_offsets < n_cols - for row_idx in tl.range(row_start, n_rows, row_step): - row_start_ptr = input_ptr + row_idx * input_row_stride + row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") + row_block = row_block * row_block #square every value the block + row_sum += (tl.sum(row_block, axis=-1) / n_cols) #tl.sum across row + + + row_norm = row_sum + eps + row_norm = tl.rsqrt(row_norm) + + #Blocked normalization + output_row_start_ptr = output_ptr + row_idx * output_row_stride + #for b in tl.range(0, n_cols, BLOCK_SIZE): + for b in tl.range(0, n_cols_blks): + col_offsets = b*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0) - row_norm = row * row #square each value - row_norm = tl.sum(row_norm, axis=-1) #sum across columns(axis=-1) - row_norm = row_norm / n_cols #divide by n_cols - row_norm = row_norm + epsilon #add epsilon - row_norm = tl.rsqrt(row_norm) #take rsqrt, this is normalization value - rms_norm = row * row_norm #multiply each x by normalization value - rms_norm = rms_norm * g #element wise multiplication with g - - output_row_start_ptr = output_ptr + row_idx * output_row_stride + row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block of input + g = tl.load(g_ptr + col_offsets, cache_modifier=".cg") #load block of g + output = row_block * row_norm #element wise multiply with rms_norm + output = output * g #element wise multiplication with g + output_ptrs = output_row_start_ptr + col_offsets - output_ptrs = tl.multiple_of(output_ptrs, (16, )) - tl.store(output_ptrs, rms_norm, mask=mask) + tl.store(output_ptrs, output) + + col_offsets = n_cols_blks*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask = col_offsets < n_cols + row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") #load block of input + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg") #load block of g + output = row_block * row_norm #element wise multiply with rms_norm + output = output * g #element wise multiplication with g + + #tl.device_print("output",output) + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, output, mask=mask) + def triton_rmsnorm(x, g, epsilon=1e-6): n_rows, n_cols = x.shape - BLOCK_SIZE = triton.next_power_of_2(n_cols) + #Restricting BLOCK_SIZE to 64Kb is an important optimization. Otherwise, + #performance can drop significantly for larger n_cols. + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols)) y = torch.empty_like(x, device='cuda') @@ -84,7 +116,6 @@ def triton_rmsnorm(x, g, epsilon=1e-6): return y - def torch_rmsnorm(x, g): M, N = x.shape if hasattr(torch.nn, 'RMSNorm'): @@ -95,15 +126,17 @@ def torch_rmsnorm(x, g): rms_norm = torch.div(x, rms.unsqueeze(1).repeat(1, N)) * g return rms_norm - +# yapf: disable @pytest.mark.parametrize('M, N', [ - (1, 4), - (2, 10), - (8192, 4096), - (4096, 8192), - (1, 8192), - (873, 1245), -]) + (1, 4), + (2, 10), + (8192, 4096), + (4096, 8192), + (1, 8192), + (873, 1245), + (1, 98304) + ]) +# yapf: enable def test_rmsnorm(M, N): torch.manual_seed(0) x = torch.randn(M, N, device='cuda') @@ -114,7 +147,6 @@ def test_rmsnorm(M, N): assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) - #Benchmark arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}