@@ -493,6 +493,126 @@ def mean_batch_invariant(input,
     return result


+@triton.jit
+def _rms_norm_kernel(
+    input_ptr,
+    weight_ptr,
+    output_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Compute RMS normalization along the last dimension of a 2D tensor.
+    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
+    Each block handles one row of the input tensor.
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+
+    # Step 1: Compute sum of squares
+    sum_sq = 0.0
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        sq_vals = vals * vals
+        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
+
+    # Step 2: Compute RMS (root mean square)
+    mean_sq = sum_sq / n_cols
+    rms = tl.sqrt(mean_sq + eps)
+    inv_rms = 1.0 / rms
+
+    # Step 3: Normalize and apply weight
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
+        output = vals * inv_rms * weight
+        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
+
+
+def rms_norm(input: torch.Tensor,
+             weight: torch.Tensor,
+             eps: float = 1e-6) -> torch.Tensor:
+    """
+    Compute RMS normalization using Triton kernel.
+
+    RMS Norm normalizes the input by the root mean square and scales by weight:
+        output = input / sqrt(mean(input^2) + eps) * weight
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tensor with RMS normalization applied along the last dimension
+    """
+    assert input.is_cuda, "Input must be a CUDA tensor"
+    assert weight.is_cuda, "Weight must be a CUDA tensor"
+    assert weight.dim() == 1, "Weight must be 1-dimensional"
+    assert input.shape[-1] == weight.shape[0], (
+        f"Input last dimension ({input.shape[-1]}) must match "
+        f"weight dimension ({weight.shape[0]})")
+
+    # Flatten all dimensions except the last one
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1])
+    input_2d = input_2d.contiguous()
+    weight = weight.contiguous()
+
+    n_rows, n_cols = input_2d.shape
+
+    output = torch.empty_like(input_2d)
+    BLOCK_SIZE = 1024
+    grid = (n_rows, )
+    _rms_norm_kernel[grid](
+        input_2d,
+        weight,
+        output,
+        input_2d.stride(0),
+        output.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return output.reshape(original_shape)
+
+
+def rms_norm_batch_invariant(input: torch.Tensor,
+                             weight: torch.Tensor,
+                             eps: float = 1e-6) -> torch.Tensor:
+    """
+    Batch-invariant wrapper for RMS normalization.
+
+    This function provides a deterministic, batch-invariant implementation
+    of RMS normalization for use with the batch_invariant mode.
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        RMS normalized tensor
+    """
+    return rms_norm(input, weight, eps=eps)
+
+
+def linear_batch_invariant(input, weight, bias=None):
+    output = torch.mm(input, weight.t())
+    if bias is not None:
+        output = output + bias
+    return output
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None

@@ -510,6 +630,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
     _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::_log_softmax",
                               _log_softmax_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
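
For context, a minimal sketch of how this addition might be exercised. It is not part of the diff: the module name `batch_invariant_ops` is a placeholder for wherever this file lives, and the reference function below simply re-derives y = x / sqrt(mean(x^2) + eps) * weight in plain PyTorch to check the Triton path; `rms_norm` and `enable_batch_invariant_mode` are the functions from the patch above.

    # Sketch only; "batch_invariant_ops" is a placeholder import path for this file.
    import torch
    import torch.nn.functional as F

    from batch_invariant_ops import rms_norm, enable_batch_invariant_mode


    def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        # Plain PyTorch restatement of y = x / sqrt(mean(x^2) + eps) * weight.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x * inv_rms * weight


    x = torch.randn(4, 128, 4096, device="cuda", dtype=torch.float32)
    w = torch.randn(4096, device="cuda", dtype=torch.float32)

    out = rms_norm(x, w, eps=1e-6)
    ref = reference_rms_norm(x, w, eps=1e-6)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

    # Once the mode is enabled, the aten::linear override above means a 2D
    # F.linear call on CUDA tensors dispatches to linear_batch_invariant
    # (torch.mm plus an explicit bias add).
    enable_batch_invariant_mode()
    a = torch.randn(8, 4096, device="cuda")
    b = torch.randn(1024, 4096, device="cuda")
    y = F.linear(a, b)  # shape (8, 1024)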