# Both approaches enable tracing, export, and shape inference in PyTorch and
# vLLM, which is essential for supporting `torch.compile` and aclgraph.

# How to add a new meta implementation in Python:
# -------------------------------------
# 1. Write a Python function that takes the same arguments as your operator
#    and returns empty output tensors with the correct shapes and dtypes,
#    without performing any real computation.
# 4. Example meta implementations are provided below for rotary_embedding
#    and get_masked_input_and_mask.
#
# 5. When developing new custom ops, always provide a meta implementation so
#    that PyTorch and vLLM can trace, export, and infer shapes for the op,
#    and so that `torch.compile` and aclgraph can capture it.
#
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors

import torch
from torch.library import Library

lib = Library("_C", "IMPL")


def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
    # Register fn as the Meta-dispatch implementation of ns::op_name,
    # skipping schemas that already have a Meta kernel registered.
    if overload != "":
        op_name = op_name + "." + overload
    schema_to_find = ns + "::" + op_name
    meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
        "Meta")
    if schema_to_find in meta_impl_list:
        return
    lib.impl(op_name, fn, "Meta")


def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
                          key: torch.Tensor, head_size: int,
                          cos_sin_cache: torch.Tensor, is_neox: bool):
    # Only shapes and dtypes matter on the meta device: derive the head
    # counts from the flattened hidden sizes (integer division, since these
    # are used as view dimensions) and allocate empty outputs.
    num_tokens = positions.numel()
    query_hidden_size = query.numel() // num_tokens
    key_hidden_size = key.numel() // num_tokens
    num_heads = query_hidden_size // head_size
    num_kv_heads = key_hidden_size // head_size

    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
    return query_dst, key_dst


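# A quick, runnable sanity check: exercising rotary_embedding_meta on
# meta-device tensors confirms the inferred output shapes. The sizes below
# (8 tokens, 4 query heads, 2 KV heads, head_size 128) are illustrative
# assumptions, not values taken from vLLM.
_pos = torch.empty(8, dtype=torch.int64, device="meta")
_q = torch.empty(8, 4 * 128, device="meta")
_k = torch.empty(8, 2 * 128, device="meta")
_cache = torch.empty(2048, 128, device="meta")
_q_out, _k_out = rotary_embedding_meta(_pos, _q, _k, 128, _cache, True)
assert _q_out.shape == (8, 4, 128) and _k_out.shape == (8, 2, 128)

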
def get_masked_input_and_mask_meta(input: torch.Tensor,
                                   org_vocab_start_index: int,
                                   org_vocab_end_index: int,
                                   num_org_vocab_padding: int,
                                   added_vocab_start_index: int,
                                   added_vocab_end_index: int):
    # Only the input's shape matters here: the vocab-range arguments affect
    # values, not shapes, so allocate an empty copy plus a boolean mask.
    masked_input = torch.empty_like(input)
    mask = torch.empty_like(input).to(torch.bool)

    return masked_input, mask


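# The same style of sanity check for get_masked_input_and_mask_meta; the
# vocab-range arguments below (0, 1000, 8, 1000, 1016) are made-up values,
# chosen only to call the function.
_tokens = torch.empty(4, 16, dtype=torch.int64, device="meta")
_masked, _mask = get_masked_input_and_mask_meta(_tokens, 0, 1000, 8,
                                                1000, 1016)
assert _masked.shape == _tokens.shape and _mask.dtype == torch.bool

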
register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C", "get_masked_input_and_mask",
                           get_masked_input_and_mask_meta)
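
# If the registrations above took effect, both schemas now appear in the Meta
# dispatch table; this reuses the same private PyTorch API that
# register_meta_if_necessary queries, so no extra assumptions are needed.
_meta_ops = torch._C._dispatch_get_registrations_for_dispatch_key("Meta")
assert "_C::rotary_embedding" in _meta_ops
assert "_C::get_masked_input_and_mask" in _meta_ops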