1- # https://github.com/pytorch/xla/blob/20899c7258680a36cd3bec1c820e8a52c16a4bbf/torch_xla/csrc/autocast_mode.cpp#L29
1+ import torch
2+ import torch ._C
3+ from torch .utils import _pytree as pytree
4+
5+ def call_with_next_key (op , args , kwargs ):
6+ return op (* args , ** kwargs )
7+
8+ target_precision = torch .bfloat16
9+
10+ def lower_precision_fp (op ):
11+ def inner (* args , ** kwargs ):
12+ target_precision = torch .get_autocast_dtype ('privateuseone' )
13+ autocast_keyset = torch ._C .DispatchKeySet (torch ._C .DispatchKey .AutocastPrivateUse1 )
14+ with torch ._C ._ExcludeDispatchKeyGuard (autocast_keyset ):
15+ is_float_tensor = lambda a : isinstance (a , torch .Tensor ) and a .is_floating_point ()
16+ args , kwargs = pytree .tree_map_only (
17+ is_float_tensor ,
18+ lambda x : x .to (target_precision ),
19+ (args , kwargs ))
20+ return op (* args , ** kwargs )
21+ return inner
222
323
24+ lib = torch .library .Library ('aten' , 'FRAGMENT' )
25+ my_lib = torch .library .Library ('_' , 'IMPL' , 'AutocastPrivateUse1' )
26+ my_lib .fallback (torch .library .fallthrough_kernel )
427
5- TORCH_LIBRARY_IMPL (aten , AutocastXLA , m ) {
6- // lower_precision_fp cast policy
7- KERNEL_XLA (conv1d , lower_precision_fp )
8- KERNEL_XLA2 (conv1d , padding , lower_precision_fp )
9- KERNEL_XLA (conv2d , lower_precision_fp )
10- KERNEL_XLA2 (conv2d , padding , lower_precision_fp )
11- KERNEL_XLA (conv3d , lower_precision_fp )
12- KERNEL_XLA2 (conv3d , padding , lower_precision_fp )
13- KERNEL_XLA (bmm , lower_precision_fp )
14- KERNEL_XLA (mm , lower_precision_fp )
15- KERNEL_XLA (baddbmm , lower_precision_fp )
16- KERNEL_XLA (addmm , lower_precision_fp )
17- KERNEL_XLA (addbmm , lower_precision_fp )
18- KERNEL_XLA (linear , lower_precision_fp )
19- KERNEL_XLA (matmul , lower_precision_fp )
20- KERNEL_XLA (conv_tbc , lower_precision_fp )
21- KERNEL_XLA (conv_transpose1d , lower_precision_fp )
22- KERNEL_XLA2 (conv_transpose2d , input , lower_precision_fp )
23- KERNEL_XLA2 (conv_transpose3d , input , lower_precision_fp )
24- KERNEL_XLA (prelu , lower_precision_fp )
25- KERNEL_XLA (relu , lower_precision_fp )
26- KERNEL_XLA (max_pool2d , lower_precision_fp )
27- KERNEL_XLA (einsum , lower_precision_fp )
28- // Disable `scaled_dot_product_attention` for now since it causes
29- // undefined symbol with official torch whl .
30- // KERNEL_XLA (scaled_dot_product_attention , lower_precision_fp )
3128
32- // fp32 cast policy
33- // Commented out ops are included in the AutoCastCPU Policy ,
34- // but not lowered . Enable if op is lowered .
35- KERNEL_XLA (batch_norm , fp32 )
36- KERNEL_XLA (_softmax , fp32 )
37- KERNEL_XLA2 (softmax , int , fp32 )
38- KERNEL_XLA2 (softmax , Dimname , fp32 )
39- KERNEL_XLA2 (log_softmax , int , fp32 )
40- KERNEL_XLA2 (log_softmax , Dimname , fp32 )
41- KERNEL_XLA (binary_cross_entropy , fp32 )
42- // KERNEL_XLA (grid_sampler , fp32 )
43- // KERNEL_XLA (polar , fp32 )
44- KERNEL_XLA2 (pow , Tensor_Scalar , fp32 )
45- KERNEL_XLA (prod , fp32 )
46- KERNEL_XLA2 (prod , dim_int , fp32 )
47- KERNEL_XLA2 (prod , dim_Dimname , fp32 )
48- // KERNEL_XLA (quantile , fp32 )
49- // KERNEL_XLA2 (quantile , scalar , fp32 )
50- // KERNEL_XLA (nanquantile , fp32 )
51- // KERNEL_XLA2 (nanquantile , scalar , fp32 )
52- // KERNEL_XLA (stft , fp32 )
53- // KERNEL_XLA2 (stft , center , fp32 )
54- KERNEL_XLA (cdist , fp32 )
55- // KERNEL_XLA (grid_sampler_2d , fp32 )
56- // KERNEL_XLA (grid_sampler_3d , fp32 )
57- KERNEL_XLA (trace , fp32 )
58- // KERNEL_XLA (view_as_complex , fp32 )
59- KERNEL_XLA (cholesky , fp32 )
60- KERNEL_XLA (cholesky_inverse , fp32 )
61- KERNEL_XLA (cholesky_solve , fp32 )
62- KERNEL_XLA (inverse , fp32 )
63- // KERNEL_XLA (lu_solve , fp32 )
64- // KERNEL_XLA (orgqr , fp32 )
65- // KERNEL_XLA (ormqr , fp32 )
66- // KERNEL_XLA (pinverse , fp32 )
67- KERNEL_XLA (reflection_pad1d , fp32 )
68- KERNEL_XLA (reflection_pad2d , fp32 )
69- KERNEL_XLA (replication_pad1d , fp32 )
70- KERNEL_XLA (replication_pad2d , fp32 )
71- KERNEL_XLA (replication_pad3d , fp32 )
72- KERNEL_XLA (mse_loss , fp32 )
73- KERNEL_XLA (cosine_embedding_loss , fp32 )
74- KERNEL_XLA (nll_loss , fp32 )
75- KERNEL_XLA (nll_loss2d , fp32 )
76- KERNEL_XLA (hinge_embedding_loss , fp32 )
77- // KERNEL_XLA (poisson_nll_loss , fp32 )
78- KERNEL_XLA (smooth_l1_loss , fp32 )
79- KERNEL_XLA (cross_entropy_loss , fp32 )
80- KERNEL_XLA (l1_loss , fp32 )
81- // KERNEL_XLA (huber_loss , fp32 )
82- KERNEL_XLA (margin_ranking_loss , fp32 )
83- KERNEL_XLA (soft_margin_loss , fp32 )
84- KERNEL_XLA (triplet_margin_loss , fp32 )
85- KERNEL_XLA (multi_margin_loss , fp32 )
86- KERNEL_XLA2 (ctc_loss , IntList , fp32 )
87- KERNEL_XLA2 (ctc_loss , Tensor , fp32 )
88- KERNEL_XLA (kl_div , fp32 )
89- KERNEL_XLA (multilabel_margin_loss , fp32 )
90- KERNEL_XLA (binary_cross_entropy_with_logits , fp32 )
91- // KERNEL_XLA (fft_fft , fp32 )
92- // KERNEL_XLA (fft_ifft , fp32 )
93- // KERNEL_XLA (fft_fft2 , fp32 )
94- // KERNEL_XLA (fft_ifft2 , fp32 )
95- // KERNEL_XLA (fft_fftn , fp32 )
96- // KERNEL_XLA (fft_ifftn , fp32 )
97- // KERNEL_XLA (fft_rfft , fp32 )
98- // KERNEL_XLA (fft_irfft , fp32 )
99- // KERNEL_XLA (fft_rfft2 , fp32 )
100- // KERNEL_XLA (fft_irfft2 , fp32 )
101- // KERNEL_XLA (fft_rfftn , fp32 )
102- // KERNEL_XLA (fft_irfftn , fp32 )
103- // KERNEL_XLA (fft_hfft , fp32 )
104- // KERNEL_XLA (fft_ihfft , fp32 )
105- // KERNEL_XLA (linalg_cond , fp32 )
106- // KERNEL_XLA2 (linalg_cond , p_str , fp32 )
107- // KERNEL_XLA (linalg_matrix_rank , fp32 )
108- // KERNEL_XLA2 (linalg_matrix_rank , tol_tensor , fp32 )
109- // KERNEL_XLA2 (linalg_matrix_rank , atol_rtol_tensor , fp32 )
110- // KERNEL_XLA2 (linalg_matrix_rank , atol_rtol_float , fp32 )
111- // KERNEL_XLA (linalg_solve , fp32 )
112- // KERNEL_XLA (linalg_cholesky , fp32 )
113- // KERNEL_XLA (linalg_svdvals , fp32 )
114- // KERNEL_XLA (linalg_eigvals , fp32 )
115- // KERNEL_XLA (linalg_eigvalsh , fp32 )
116- // KERNEL_XLA (linalg_inv , fp32 )
117- // KERNEL_XLA (linalg_householder_product , fp32 )
118- // KERNEL_XLA (linalg_tensorinv , fp32 )
119- // KERNEL_XLA (linalg_tensorsolve , fp32 )
120- // KERNEL_XLA (fake_quantize_per_tensor_affine , fp32 )
121- // KERNEL_XLA (geqrf , fp32 )
122- // KERNEL_XLA (_lu_with_info , fp32 )
123- KERNEL_XLA (qr , fp32 )
124- KERNEL_XLA (svd , fp32 )
125- KERNEL_XLA (triangular_solve , fp32 )
126- KERNEL_XLA (multilabel_margin_loss_forward , fp32 )
127- // KERNEL_XLA (linalg_qr , fp32 )
128- // KERNEL_XLA (linalg_cholesky_ex , fp32 )
129- KERNEL_XLA (linalg_svd , fp32 )
130- // KERNEL_XLA (linalg_eig , fp32 )
131- // KERNEL_XLA (linalg_eigh , fp32 )
132- // KERNEL_XLA (linalg_lstsq , fp32 )
133- KERNEL_XLA (linalg_inv_ex , fp32 )
29+ for op in [torch .ops .aten .conv1d .default ,
30+ torch .ops .aten .conv1d .padding ,
31+ torch .ops .aten .conv2d .default ,
32+ torch .ops .aten .conv2d .padding ,
33+ torch .ops .aten .conv3d .default ,
34+ torch .ops .aten .bmm .default ,
35+ torch .ops .aten .mm .default ,
36+ torch .ops .aten .baddbmm .default ,
37+ torch .ops .aten .addmm .default ,
38+ torch .ops .aten .addbmm .default ,
39+ torch .ops .aten .linear .default ,
40+ torch .ops .aten .matmul .default ,
41+ torch .ops .aten .conv_tbc .default ,
42+ torch .ops .aten .conv_transpose1d .default ,
43+ torch .ops .aten .conv_transpose2d .input ,
44+ torch .ops .aten .conv_transpose3d .input ,
45+ torch .ops .aten .prelu .default ,
46+ torch .ops .aten .relu .default ,
47+ torch .ops .aten .max_pool2d .default ,
48+ torch .ops .aten .einsum .default ,
49+ ]:
50+ lib .impl (op .name (), lower_precision_fp (op ), "AutocastPrivateUse1" , with_keyset = False )
51+
52+ # https://github.com/pytorch/xla/blob/20899c7258680a36cd3bec1c820e8a52c16a4bbf/torch_xla/csrc/autocast_mode.cpp#L29
53+ # enum class CastPolicy : uint8_t {
54+ # lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
55+ # // running the op. Currently, lower_precision_fp is
56+ # // fp16 for AutocastCUDA, and is defined by user
57+ # // (default bf16) for AutocastCPU or other device.
58+ # fp32, // Cast all inputs to at::kFloat before running the op.
59+ # fp32_set_opt_dtype, // Treats functions (like softmax) that
60+ # // 1. we'd like to run in fp32 and
61+ # // 2. have a std::optional<ScalarType> arg that controls
62+ # // the output type.
63+ # // fp32_set_opt_dtype wrappers' policy is: if the output
64+ # // type is already set, don't touch it, otherwise, set
65+ # // it to at::kFloat.
66+ # fp32_append_dtype, // Treats functions (like norm) that
67+ # // 1. we'd like to run in fp32 and
68+ # // 2. have some overloads that accept an output type and
69+ # // other overloads that don't.
70+ # // fp32_append_dtype wrappers wrap the overloads that don't
71+ # // have an output dtype.
72+ # // The wrapper policy is: append at::kFloat to the args,
73+ # // and redispatch to the type-aware overload.
74+ # promote, // Run in the widest dtype among several args.
75+ # };
76+ # TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
77+ # // lower_precision_fp cast policy
78+ # KERNEL_XLA(conv1d, lower_precision_fp)
79+ # KERNEL_XLA2(conv1d, padding, lower_precision_fp)
80+ # KERNEL_XLA(conv2d, lower_precision_fp)
81+ # KERNEL_XLA2(conv2d, padding, lower_precision_fp)
82+ # KERNEL_XLA(conv3d, lower_precision_fp)
83+ # KERNEL_XLA2(conv3d, padding, lower_precision_fp)
84+ # KERNEL_XLA(bmm, lower_precision_fp)
85+ # KERNEL_XLA(mm, lower_precision_fp)
86+ # KERNEL_XLA(baddbmm, lower_precision_fp)
87+ # KERNEL_XLA(addmm, lower_precision_fp)
88+ # KERNEL_XLA(addbmm, lower_precision_fp)
89+ # KERNEL_XLA(linear, lower_precision_fp)
90+ # KERNEL_XLA(matmul, lower_precision_fp)
91+ # KERNEL_XLA(conv_tbc, lower_precision_fp)
92+ # KERNEL_XLA(conv_transpose1d, lower_precision_fp)
93+ # KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp)
94+ # KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp)
95+ # KERNEL_XLA(prelu, lower_precision_fp)
96+ # KERNEL_XLA(relu, lower_precision_fp)
97+ # KERNEL_XLA(max_pool2d, lower_precision_fp)
98+ # KERNEL_XLA(einsum, lower_precision_fp)
99+ # // Disable `scaled_dot_product_attention` for now since it causes
100+ # // undefined symbol with official torch whl.
101+ # // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)
102+
103+ # // fp32 cast policy
104+ # // Commented out ops are included in the AutoCastCPU Policy,
105+ # // but not lowered. Enable if op is lowered.
106+ # KERNEL_XLA(batch_norm, fp32)
107+ # KERNEL_XLA(_softmax, fp32)
108+ # KERNEL_XLA2(softmax, int, fp32)
109+ # KERNEL_XLA2(softmax, Dimname, fp32)
110+ # KERNEL_XLA2(log_softmax, int, fp32)
111+ # KERNEL_XLA2(log_softmax, Dimname, fp32)
112+ # KERNEL_XLA(binary_cross_entropy, fp32)
113+ # // KERNEL_XLA(grid_sampler, fp32)
114+ # // KERNEL_XLA(polar, fp32)
115+ # KERNEL_XLA2(pow, Tensor_Scalar, fp32)
116+ # KERNEL_XLA(prod, fp32)
117+ # KERNEL_XLA2(prod, dim_int, fp32)
118+ # KERNEL_XLA2(prod, dim_Dimname, fp32)
119+ # // KERNEL_XLA(quantile, fp32)
120+ # // KERNEL_XLA2(quantile, scalar, fp32)
121+ # // KERNEL_XLA(nanquantile, fp32)
122+ # // KERNEL_XLA2(nanquantile, scalar, fp32)
123+ # // KERNEL_XLA(stft, fp32)
124+ # // KERNEL_XLA2(stft, center, fp32)
125+ # KERNEL_XLA(cdist, fp32)
126+ # // KERNEL_XLA(grid_sampler_2d, fp32)
127+ # // KERNEL_XLA(grid_sampler_3d, fp32)
128+ # KERNEL_XLA(trace, fp32)
129+ # // KERNEL_XLA(view_as_complex, fp32)
130+ # KERNEL_XLA(cholesky, fp32)
131+ # KERNEL_XLA(cholesky_inverse, fp32)
132+ # KERNEL_XLA(cholesky_solve, fp32)
133+ # KERNEL_XLA(inverse, fp32)
134+ # // KERNEL_XLA(lu_solve, fp32)
135+ # // KERNEL_XLA(orgqr, fp32)
136+ # // KERNEL_XLA(ormqr, fp32)
137+ # // KERNEL_XLA(pinverse, fp32)
138+ # KERNEL_XLA(reflection_pad1d, fp32)
139+ # KERNEL_XLA(reflection_pad2d, fp32)
140+ # KERNEL_XLA(replication_pad1d, fp32)
141+ # KERNEL_XLA(replication_pad2d, fp32)
142+ # KERNEL_XLA(replication_pad3d, fp32)
143+ # KERNEL_XLA(mse_loss, fp32)
144+ # KERNEL_XLA(cosine_embedding_loss, fp32)
145+ # KERNEL_XLA(nll_loss, fp32)
146+ # KERNEL_XLA(nll_loss2d, fp32)
147+ # KERNEL_XLA(hinge_embedding_loss, fp32)
148+ # // KERNEL_XLA(poisson_nll_loss, fp32)
149+ # KERNEL_XLA(smooth_l1_loss, fp32)
150+ # KERNEL_XLA(cross_entropy_loss, fp32)
151+ # KERNEL_XLA(l1_loss, fp32)
152+ # // KERNEL_XLA(huber_loss, fp32)
153+ # KERNEL_XLA(margin_ranking_loss, fp32)
154+ # KERNEL_XLA(soft_margin_loss, fp32)
155+ # KERNEL_XLA(triplet_margin_loss, fp32)
156+ # KERNEL_XLA(multi_margin_loss, fp32)
157+ # KERNEL_XLA2(ctc_loss, IntList, fp32)
158+ # KERNEL_XLA2(ctc_loss, Tensor, fp32)
159+ # KERNEL_XLA(kl_div, fp32)
160+ # KERNEL_XLA(multilabel_margin_loss, fp32)
161+ # KERNEL_XLA(binary_cross_entropy_with_logits, fp32)
162+ # // KERNEL_XLA(fft_fft, fp32)
163+ # // KERNEL_XLA(fft_ifft, fp32)
164+ # // KERNEL_XLA(fft_fft2, fp32)
165+ # // KERNEL_XLA(fft_ifft2, fp32)
166+ # // KERNEL_XLA(fft_fftn, fp32)
167+ # // KERNEL_XLA(fft_ifftn, fp32)
168+ # // KERNEL_XLA(fft_rfft, fp32)
169+ # // KERNEL_XLA(fft_irfft, fp32)
170+ # // KERNEL_XLA(fft_rfft2, fp32)
171+ # // KERNEL_XLA(fft_irfft2, fp32)
172+ # // KERNEL_XLA(fft_rfftn, fp32)
173+ # // KERNEL_XLA(fft_irfftn, fp32)
174+ # // KERNEL_XLA(fft_hfft, fp32)
175+ # // KERNEL_XLA(fft_ihfft, fp32)
176+ # // KERNEL_XLA(linalg_cond, fp32)
177+ # // KERNEL_XLA2(linalg_cond, p_str, fp32)
178+ # // KERNEL_XLA(linalg_matrix_rank, fp32)
179+ # // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32)
180+ # // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32)
181+ # // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32)
182+ # // KERNEL_XLA(linalg_solve, fp32)
183+ # // KERNEL_XLA(linalg_cholesky, fp32)
184+ # // KERNEL_XLA(linalg_svdvals, fp32)
185+ # // KERNEL_XLA(linalg_eigvals, fp32)
186+ # // KERNEL_XLA(linalg_eigvalsh, fp32)
187+ # // KERNEL_XLA(linalg_inv, fp32)
188+ # // KERNEL_XLA(linalg_householder_product, fp32)
189+ # // KERNEL_XLA(linalg_tensorinv, fp32)
190+ # // KERNEL_XLA(linalg_tensorsolve, fp32)
191+ # // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
192+ # // KERNEL_XLA(geqrf, fp32)
193+ # // KERNEL_XLA(_lu_with_info, fp32)
194+ # KERNEL_XLA(qr, fp32)
195+ # KERNEL_XLA(svd, fp32)
196+ # KERNEL_XLA(triangular_solve, fp32)
197+ # KERNEL_XLA(multilabel_margin_loss_forward, fp32)
198+ # // KERNEL_XLA(linalg_qr, fp32)
199+ # // KERNEL_XLA(linalg_cholesky_ex, fp32)
200+ # KERNEL_XLA(linalg_svd, fp32)
201+ # // KERNEL_XLA(linalg_eig, fp32)
202+ # // KERNEL_XLA(linalg_eigh, fp32)
203+ # // KERNEL_XLA(linalg_lstsq, fp32)
204+ # KERNEL_XLA(linalg_inv_ex, fp32)
134205
135- // promote
136- KERNEL_XLA (stack , promote )
137- KERNEL_XLA (cat , promote )
138- KERNEL_XLA (index_copy , promote )
139- KERNEL_XLA2 (index_copy , dimname , promote )
206+ # // promote
207+ # KERNEL_XLA(stack, promote)
208+ # KERNEL_XLA(cat, promote)
209+ # KERNEL_XLA(index_copy, promote)
210+ # KERNEL_XLA2(index_copy, dimname, promote)
0 commit comments