@@ -30,12 +30,12 @@ def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
3030 D: Input dim
3131 L: LoRA Dim
3232 N: N LoRAs
33-
33+
3434 Outputs:
3535 inputs: torch.Tensor - shape (T, D)
3636 loras: torch.Tensor - shape (N, 1, L, D)
3737 idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)
38-
38+
3939 ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
4040 """
4141 torch .manual_seed (seed )
@@ -84,3 +84,28 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
8484
8585 # Compare with reference output
8686 assert torch .allclose (output , ref_output , rtol = 1e-2 , atol = 1e-2 )
87+
88+ # Parameterize tests with various shapes and dtypes
89+ @pytest .mark .parametrize ("T" , N_TOKENS )
90+ @pytest .mark .parametrize ("D" , HIDDEN_SIZES )
91+ @pytest .mark .parametrize ("L" , RANKS )
92+ @pytest .mark .parametrize ("N" , NUM_LORA )
93+ @pytest .mark .parametrize ("dtype" , DTYPES )
94+ @pytest .mark .parametrize ("seed" , [0 ])
95+ def test_lora_laning_correctness (T , D , L , N , dtype , seed ):
96+ inputs , loras_a , idxs , _ = generate_test_data (T , D , L , N , seed , dtype )
97+ _ , loras_b , _ , _ = generate_test_data (T , L , D , N , seed , dtype )
98+
99+ r1 = ref_bgmv (inputs , loras_a , idxs )
100+ r2 = ref_bgmv (r1 , loras_b , idxs )
101+
102+ o1 = torch .ops .xla .bgmv_shrink (inputs , loras_a , idxs )
103+ o2 = torch .ops .xla .bgmv_expand (
104+ o1 ,
105+ loras_b .transpose (2 , 3 ),
106+ idxs ,
107+ True
108+ )
109+
110+ # Compare with reference output
111+ assert torch .allclose (o2 , r2 , rtol = 1e-2 , atol = 1e-2 )
0 commit comments