@@ -692,5 +692,55 @@ def test_error2():
692692 self .assertRaises (AssertionError , test_error2 )
693693
694694
@unittest.skipIf(
    not core.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(),
    "core is not compiled with CUDA or ROCM ",
)
class TestFusedRotaryPositionEmbeddingZeroSize(unittest.TestCase):
    """Run fused_rotary_position_embedding on zero-size q/k/v tensors.

    The first dimension of q/k/v is 0, so the tensors hold no elements and
    there are no values to compare. The test therefore only verifies that the
    forward and backward passes run and that output and gradient shapes match
    the input shapes.
    """

    def setUp(self):
        """Define the dtype and the shapes used to build the inputs."""
        self.dtype = "float32"
        # First dimension is 0: q/k/v are empty tensors.
        self.qkv_shape = [0, 1, 8, 8]
        self.sin_cos_shape = [1, 1, 1, 8]

    def init_data(self):
        """Create q/k/v with gradients enabled, plus the sin/cos inputs."""
        self.q = paddle.randn(self.qkv_shape, dtype=self.dtype)
        self.k = paddle.randn(self.qkv_shape, dtype=self.dtype)
        self.v = paddle.randn(self.qkv_shape, dtype=self.dtype)
        # Gradients are required so backward() populates q/k/v .grad.
        self.q.stop_gradient = False
        self.k.stop_gradient = False
        self.v.stop_gradient = False
        self.sin = paddle.sin(
            paddle.randn(self.sin_cos_shape, dtype=self.dtype)
        )
        self.cos = paddle.cos(
            paddle.randn(self.sin_cos_shape, dtype=self.dtype)
        )

    def _test_forward_backward(self):
        """Run forward and backward and check that every shape is preserved."""
        out_q, out_k, out_v = fused_rotary_position_embedding(
            self.q,
            self.k,
            self.v,
            sin=self.sin,
            cos=self.cos,
            use_neox_rotary_style=False,
        )
        named_inputs = (("q", self.q), ("k", self.k), ("v", self.v))

        # Forward: each output is expected to keep the shape of its input.
        for (name, tensor), result in zip(named_inputs, (out_q, out_k, out_v)):
            self.assertEqual(
                list(result.shape),
                list(tensor.shape),
                f"out_{name} shape differs from {name} shape",
            )

        out = out_q + out_k + out_v
        out.backward()

        # Backward: shapes are integers, so compare them exactly rather than
        # with floating-point tolerances.
        for name, tensor in named_inputs:
            self.assertIsNotNone(
                tensor.grad, f"{name}.grad was not populated by backward()"
            )
            self.assertEqual(
                list(tensor.grad.shape),
                list(tensor.shape),
                f"{name}.grad shape differs from {name} shape",
            )

    def test_zero_size(self):
        """Build the zero-size inputs and run the forward/backward check."""
        self.init_data()
        self._test_forward_backward()
743+
744+
# Run every test case defined in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments