 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward, GMM
+from torch_xla.experimental.custom_kernel import gmm, tgmm, gmm_backward, GMM
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu

@@ -120,10 +120,11 @@ def test_gmm(self):
       # torch.compiled version of the gmm will cache the payload in dynamo layer
       # hence won't trigger the trace_pallas cache
       if test_cache and gmm_func != compiled_gmm:
-        met.clear_counters()
+        old_cnt = xr.get_num_cached_compilation_graph()
         # execute the same gmm func, expected to hit the cache
         out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        new_cnt = xr.get_num_cached_compilation_graph()
+        self.assertEqual(old_cnt, new_cnt)
       self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure gmm doesn't fallback.
@@ -155,173 +156,16 @@ def test_gmm_bf16(self):
       # torch.compiled version of the gmm will cache the payload in dynamo layer
       # hence won't trigger the trace_pallas cache
       if test_cache and gmm_func != compiled_gmm:
-        met.clear_counters()
+        old_cnt = xr.get_num_cached_compilation_graph()
         # execute the same gmm func, expected to hit the cache
         out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        new_cnt = xr.get_num_cached_compilation_graph()
+        self.assertEqual(old_cnt, new_cnt)
       self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure gmm doesn't fallback.
     self.assertEqual(len(torch_xla._XLAC._get_executed_fallback_ops()), 0)

-  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
-  def test_make_group_metadata(self):
-    from jax.experimental.pallas.ops.tpu.megablox.gmm import make_group_metadata as jax_make_group_metadata
-    met.clear_all()
-
-    test_grids = [
-        {
-            'group_sizes': [8, 8, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [2, 14, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [16, 0, 8, 8],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [2, 0, 14, 16],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [8, 12, 0, 12],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [6, 12, 0, 14],
-            'm': 32,
-            'tm': 8
-        },
-        {
-            'group_sizes': [6, 12, 0, 14],
-            'm': 32,
-            'tm': 4
-        },
-        {
-            'group_sizes': [377, 588, 153, 1638, 3261, 5890, 996, 3481],
-            'm': 16384,
-            'tm': 128
-        },
-    ]
-
-    for test_grid in test_grids:
-      jax_meta, jax_num_tiles = jax_make_group_metadata(
-          group_sizes=jnp.array(test_grid['group_sizes']),
-          m=test_grid['m'],
-          tm=test_grid['tm'],
-          start_group=0,
-          num_nonzero_groups=len(test_grid['group_sizes']),
-      )
-
-      torch_meta = _make_group_metadata(
-          group_sizes=torch.tensor(test_grid['group_sizes']).to(
-              torch.int32).to("xla"),
-          m=test_grid['m'],
-          tm=test_grid['tm'],
-          visit_empty_groups=True,
-      )
-
-      for i in range(len(jax_meta)):
-        self.assertTrue(
-            torch.all(
-                torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i].cpu()))
-      self.assertEqual(jax_num_tiles, torch_meta[-1].cpu().item())
-
-    # Make sure _make_group_metadata doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
-
-  def test_histogram(self):
-    test_grids = [
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 1,
-            'max': 4,
-        },
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 2,
-            'max': 3,
-        },
-        {
-            'input': [1, 4, 4, 1, 2, 3],
-            'min': 0,
-            'max': 5,
-        },
-        {
-            'input': [],
-            'min': 0,
-            'max': 5,
-        },
-    ]
-
-    for test_grid in test_grids:
-      torch_chart = torch.histc(
-          torch.tensor(test_grid['input'], dtype=torch.float),
-          bins=test_grid['max'] - test_grid['min'] + 1,
-          min=test_grid['min'],
-          max=test_grid['max'],
-      )
-
-      chart = _histogram(
-          torch.tensor(test_grid['input'], dtype=torch.int32).to("xla"),
-          min=test_grid['min'],
-          max=test_grid['max'],
-      )
-
-      self.assertEqual(chart.dtype, torch.int32)
-      self.assertTrue(torch.all(torch_chart == chart.cpu()))
-
-  def test_histogram_raise(self):
-    with self.assertRaisesRegex(AssertionError,
-                                "input must be of torch.int32 dtype."):
-      _histogram(
-          torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.float),
-          min=4,
-          max=5,
-      )
-
-    with self.assertRaisesRegex(AssertionError,
-                                "min must be less than or equal to max."):
-      _histogram(
-          torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32),
-          min=4,
-          max=3,
-      )
-
-  def test_sorting_input(self):
-    met.clear_all()
-    top2 = torch.tensor([[0, 2], [1, 3], [1, 2], [2, 3]]).to("xla")
-
-    # We want to create one big batch of tokens that has all top-k choices in it.
-    # Our tokens will thus be duplicated k-times in the batch. To do this we,
-    # first flatten the expert choices list and argsort it. This gives us an array
-    # of length B * K. We then create a tiled arange of size B * K and index
-    # into the expert choices list. This will give us the set of indices we need
-    # to gather from the xs to create this big batch.
-    top_flat = top2.flatten()
-    lhs_order = top_flat.argsort()
-    lhs_reverse_order = lhs_order.argsort()
-    lhs_indices = torch.arange(
-        top2.shape[0], device="xla").repeat_interleave(2)[lhs_order]
-    group_sizes = _histogram(top_flat.to(torch.int32), 0, 3)
-    torch_xla.sync()
-
-    # Make sure it doesn't fallback.
-    self.assertNotIn("aten::", met.short_metrics_report())
-    self.assertTrue(
-        torch.all(lhs_indices == torch.tensor([0, 1, 2, 0, 3, 2, 1, 3],
-                                              device="xla")))
-    self.assertTrue(
-        torch.all(group_sizes == torch.tensor([1, 2, 3, 2], device="xla")))
-
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tgmm(self):
     met.clear_all()
@@ -343,10 +187,11 @@ def test_tgmm(self):

       out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
       if test_cache:
-        met.clear_counters()
+        old_cnt = xr.get_num_cached_compilation_graph()
         # execute the same gmm func, expected to hit the cache
         out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        new_cnt = xr.get_num_cached_compilation_graph()
+        self.assertEqual(new_cnt, old_cnt)
       self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure tgmm doesn't fallback.
@@ -373,10 +218,11 @@ def test_tgmm_bf16(self):

       out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
       if test_cache:
-        met.clear_counters()
+        old_cnt = xr.get_num_cached_compilation_graph()
         # execute the same gmm func, expected to hit the cache
         out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-        self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 1)
+        new_cnt = xr.get_num_cached_compilation_graph()
+        self.assertEqual(new_cnt, old_cnt)
       self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure tgmm doesn't fallback.
@@ -393,7 +239,7 @@ def test_gmm_backward(self):
       lhs_dtype = rhs_dtype = torch.bfloat16

       for test_cache in [False, True]:
-        met.clear_all()
+        old_cnt = xr.get_num_cached_compilation_graph()
         lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
         rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
         group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
@@ -409,8 +255,9 @@ def test_gmm_backward(self):
             group_sizes.to("xla"))
         # same gmm/tgmm was run for the `test_cache=False` case so the
         # cache should be populated now
+        new_cnt = xr.get_num_cached_compilation_graph()
         if test_cache:
-          self.assertEqual(met.counter_value('trace_pallas_cache_hit'), 2)
+          self.assertEqual(new_cnt, old_cnt)

         self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
         self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
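
Note on the assertion pattern this diff switches to: instead of asserting on the trace_pallas_cache_hit counter from torch_xla.debug.metrics, the tests now read the number of cached compilation graphs before and after re-running an already-traced kernel and assert that the count did not grow. A minimal standalone sketch of that check follows; it assumes the xr.get_num_cached_compilation_graph() accessor used in the diff, and the assert_no_recompilation helper name is hypothetical, not part of the torch_xla API.

# Sketch only: re-running a kernel that has already been traced and compiled
# should not add a new compilation graph to the runtime cache.
# assert_no_recompilation is a hypothetical helper for illustration;
# xr.get_num_cached_compilation_graph() is the accessor used in the diff above.
import torch_xla
from torch_xla import runtime as xr


def assert_no_recompilation(fn, *args):
  # Number of compilation graphs cached before the re-run.
  old_cnt = xr.get_num_cached_compilation_graph()
  out = fn(*args)
  torch_xla.sync()  # materialize the lazy graph before re-reading the count
  new_cnt = xr.get_num_cached_compilation_graph()
  assert new_cnt == old_cnt, "kernel re-run triggered a new compilation graph"
  return out


# Usage, assuming gmm was already called once with tensors of the same shapes:
#   out = assert_no_recompilation(
#       gmm, lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))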