@@ -70,7 +70,7 @@ def test_rotary_embedding(
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
@@ -125,7 +125,7 @@ def test_batched_rotary_embedding(
125125 "rope_type" : "linear" ,
126126 "factor" : (1 , )
127127 })
128- rope = rope .to (dtype = dtype )
128+ rope = rope .to (dtype = dtype , device = torch . get_default_device () )
129129
130130 positions = torch .randint (0 , max_position , (batch_size , seq_len ))
131131 query_shape = tensor_shape_fn (batch_size , seq_len , num_heads , head_size )
@@ -184,7 +184,7 @@ def test_batched_rotary_embedding_multi_lora(
184184 "rope_type" : "linear" ,
185185 "factor" : tuple (scaling_factors )
186186 })
187- rope = rope .to (dtype = dtype )
187+ rope = rope .to (dtype = dtype , device = torch . get_default_device () )
188188
189189 positions = torch .randint (0 , max_position , (batch_size , seq_len ))
190190 query = torch .randn (batch_size ,