1212def benchmark_npu (fn , num_iterations = 100 , num_warmup_iterations = 50 ):
1313 """
1414 Benchmark function for NPU operations
15-
15+
1616 Args:
1717 fn: Function to benchmark
1818 num_iterations: Number of timing iterations
1919 num_warmup_iterations: Number of warmup iterations
20-
20+
2121 Returns:
2222 float: Minimum elapsed time in seconds
2323 """
@@ -41,19 +41,26 @@ def benchmark_npu(fn, num_iterations=100, num_warmup_iterations=50):
4141
4242
def get_masked_input_and_mask_ref(
    input_: torch.Tensor,
    org_vocab_start_index: int,
    org_vocab_end_index: int,
    num_org_vocab_padding: int,
    added_vocab_start_index: int,
    added_vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation for verification.

    Remaps the ids in ``input_`` that fall inside one of two half-open
    ranges and zeroes out every other id.

    Args:
        input_: Tensor of ids to remap.
        org_vocab_start_index: Inclusive lower bound of the original range.
        org_vocab_end_index: Exclusive upper bound of the original range.
        num_org_vocab_padding: Number of padding slots that sit between the
            remapped original range and the remapped added range.
        added_vocab_start_index: Inclusive lower bound of the added range.
        added_vocab_end_index: Exclusive upper bound of the added range.

    Returns:
        A ``(masked_input, mask)`` pair. ``masked_input`` holds
        ``id - org_vocab_start_index`` for ids in the original range,
        ``id - added_offset`` for ids in the added range, and ``0``
        elsewhere. ``mask`` is ``True`` exactly where the id lies in
        neither range.
    """
    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index
    )
    # Added ids are placed right after the original range plus its padding:
    # the first added id maps to (org range length + padding).
    added_offset = (
        added_vocab_start_index
        - (org_vocab_end_index - org_vocab_start_index)
        - num_org_vocab_padding
    )
    # The masks are boolean, so each product selects the offset that applies
    # to that element (and contributes 0 where the mask is False).
    valid_offset = (org_vocab_start_index * org_vocab_mask) + (
        added_offset * added_vocab_mask
    )
    vocab_mask = org_vocab_mask | added_vocab_mask
    # Multiplying by the combined mask zeroes ids that are in neither range.
    masked_input = vocab_mask * (input_ - valid_offset)
    return masked_input, ~vocab_mask
@@ -94,21 +101,25 @@ def test_get_masked_input_and_mask(
94101
95102 # Define reference function
def ref_fn():
    """Run the reference implementation on the current test inputs.

    Takes no arguments: ``input_tensor`` and ``test_case`` are read from
    the enclosing test's scope.
    """
    return get_masked_input_and_mask_ref(
        input_tensor,
        test_case["org_start"],
        test_case["org_end"],
        test_case["padding"],
        test_case["added_start"],
        test_case["added_end"],
    )
103112
104113 # Define custom function
def custom_fn():
    """Run the custom ``torch.ops._C`` operator on the current test inputs.

    Takes no arguments: ``input_tensor`` and ``test_case`` are read from
    the enclosing test's scope. The arguments are passed in the same
    order as for the reference implementation.
    """
    return torch.ops._C.get_masked_input_and_mask(
        input_tensor,
        test_case["org_start"],
        test_case["org_end"],
        test_case["padding"],
        test_case["added_start"],
        test_case["added_end"],
    )
112123
113124 # Get results for correctness testing
114125 ref_masked_input , ref_mask = ref_fn ()
@@ -120,9 +131,9 @@ def custom_fn():
120131
121132 # Print performance results
122133 print ("\n Performance Results:" )
123- print (f"Reference implementation: { ref_time * 1000 :.3f} ms" )
124- print (f"Custom implementation: { custom_time * 1000 :.3f} ms" )
125- print (f"Speedup: { ref_time / custom_time :.2f} x" )
134+ print (f"Reference implementation: { ref_time * 1000 :.3f} ms" )
135+ print (f"Custom implementation: { custom_time * 1000 :.3f} ms" )
136+ print (f"Speedup: { ref_time / custom_time :.2f} x" )
126137
127138 # Compare results for correctness
128139 ref_masked_input = ref_masked_input .to (dtype )
@@ -136,9 +147,12 @@ def custom_fn():
136147 ref_masked_input ,
137148 rtol = 1e-5 ,
138149 atol = 1e-5 ,
139- msg = f"Masked input mismatch for case: { test_case } " )
140- torch .testing .assert_close (custom_mask ,
141- ref_mask ,
142- rtol = 1e-5 ,
143- atol = 1e-5 ,
144- msg = f"Mask mismatch for case: { test_case } " )
150+ msg = f"Masked input mismatch for case: { test_case } " ,
151+ )
152+ torch .testing .assert_close (
153+ custom_mask ,
154+ ref_mask ,
155+ rtol = 1e-5 ,
156+ atol = 1e-5 ,
157+ msg = f"Mask mismatch for case: { test_case } " ,
158+ )
0 commit comments