1+ import torch
2+ from torch .library import Library
3+
# Fragment library targeting the existing "_C" operator namespace in "IMPL"
# mode: used below to attach Meta-dispatch-key kernels to already-defined ops.
lib = Library("_C", "IMPL")
5+
def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
    """Register ``fn`` as the Meta-kernel for ``ns::op_name`` if none exists.

    Builds the fully qualified op name (appending ``.overload`` when an
    overload is given), consults the dispatcher's current Meta registrations,
    and skips registration when the schema is already covered so we never
    clobber an implementation provided elsewhere (e.g. by a newer library
    version).
    """
    qualified = f"{op_name}.{overload}" if overload else op_name
    registered = torch._C._dispatch_get_registrations_for_dispatch_key("Meta")
    if f"{ns}::{qualified}" not in registered:
        lib.impl(qualified, fn, "Meta")
14+
def rotary_embedding_meta(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        head_size: int,
        cos_sin_cache: torch.Tensor,
        is_neox: bool):
    """Meta ('fake tensor') implementation of the rotary_embedding custom op.

    Produces output tensors with the correct shapes/dtypes without running the
    real kernel, so torch.compile / fake-tensor tracing can propagate shapes.
    ``cos_sin_cache`` and ``is_neox`` are unused here; they only matter to the
    real kernel.

    Returns:
        (query_dst, key_dst) shaped (num_tokens, num_heads, head_size) and
        (num_tokens, num_kv_heads, head_size), matching query/key dtypes.
    """
    num_tokens = positions.numel()
    # BUG FIX: use floor division. True division (`/`) yields floats, and
    # Tensor.view() requires int sizes, so the original raised TypeError.
    query_hidden_size = query.numel() // num_tokens
    key_hidden_size = key.numel() // num_tokens
    num_heads = query_hidden_size // head_size
    num_kv_heads = key_hidden_size // head_size

    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
    return query_dst, key_dst
32+
33+
def get_masked_input_and_mask_meta(
        input: torch.Tensor,
        org_vocab_start_index: int,
        org_vocab_end_index: int,
        num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int):
    """Meta ('fake tensor') implementation of get_masked_input_and_mask.

    Only shapes/dtypes matter here: returns an uninitialized tensor shaped
    like ``input`` plus a same-shaped boolean mask. The vocab-range arguments
    are unused by the meta kernel; they drive the real implementation.
    """
    shadow = torch.empty_like(input)
    out_of_range = torch.empty_like(input, dtype=torch.bool)
    return shadow, out_of_range
46+
47+
48+
49+ register_meta_if_necessary ("_C" , "rotary_embedding" , rotary_embedding_meta )
50+ register_meta_if_necessary ("_C" , "get_masked_input_and_mask" , get_masked_input_and_mask_meta )