@@ -233,8 +233,20 @@ def forward(
         offsets: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # we simulate a simple attention layer to test whether it can be seamlessly captured into an aclgraph
-        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
-        query, key = self.rope.forward_native(positions, q, k, offsets)
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(3, dim=-1)
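+        # apply rotary position embedding via the custom kernel under torch.ops._C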
+        query, key = torch.ops._C.rotary_embedding(
+            positions,
+            q,
+            k,
+            self.rope.head_size,
+            self.rope.cos_sin_cache,
+            self.rope.is_neox_style,
+        )
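+        # view the outputs back to the original q/k shapes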
+        query = query.view(q.shape)
+        key = key.view(k.shape)
         o = self.o_proj(query)
         return o
 
@@ -257,14 +269,17 @@ def test_capture_rotary_embedding_in_aclgraph(
     dtype: torch.dtype,
     seed: int,
     device: str,
-    max_position_embeddings: int,
-    base: int,
+    max_position_embeddings: int = 8192,
+    base: int = 10000,
 ):
263273 """Test if the rotary embedding can be captured in aclgraph."""
     torch.manual_seed(seed)
     torch.set_default_device(device)
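+    # default to rotating the full head dimension when rotary_dim is unset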
+    if rotary_dim is None:
+        rotary_dim = head_size
     model = ModelwithRotaryEmbedding(
-        hidden_size=num_tokens,
+        hidden_size=num_heads * head_size,
         num_heads=num_heads,
         head_size=head_size,
         rotary_dim=rotary_dim,
@@ -274,13 +289,22 @@ def test_capture_rotary_embedding_in_aclgraph(
         dtype=dtype,
     )
 
+    def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
+        # Validate that the rotary_embedding custom kernel is indeed inside
+        # the graph by string matching
+        graph = str(gm.graph)
+        assert "_C.rotary_embedding" in graph
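+        # return the captured graph unchanged so execution proceeds as usual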
+        return gm
+
     static_positions = torch.randint(0, max_position_embeddings,
                                      (num_tokens, ))
     static_hidden_states = torch.randn(num_tokens,
                                        num_heads * head_size,
                                        dtype=dtype,
                                        device="npu")
-    compiled_model = torch.compile(model)
+    compiled_model = torch.compile(model, backend=custom_op_checking_backend)
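+    # run on a side stream first, the usual warm-up pattern before graph capture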
     stream = torch.npu.Stream()
     stream.wait_stream(torch.npu.current_stream())
     with torch.npu.stream(stream):