1+ # https://github.com/pytorch/xla/blob/20899c7258680a36cd3bec1c820e8a52c16a4bbf/torch_xla/csrc/autocast_mode.cpp#L29
2+
3+
4+
5+ TORCH_LIBRARY_IMPL (aten , AutocastXLA , m ) {
6+ // lower_precision_fp cast policy
7+ KERNEL_XLA (conv1d , lower_precision_fp )
8+ KERNEL_XLA2 (conv1d , padding , lower_precision_fp )
9+ KERNEL_XLA (conv2d , lower_precision_fp )
10+ KERNEL_XLA2 (conv2d , padding , lower_precision_fp )
11+ KERNEL_XLA (conv3d , lower_precision_fp )
12+ KERNEL_XLA2 (conv3d , padding , lower_precision_fp )
13+ KERNEL_XLA (bmm , lower_precision_fp )
14+ KERNEL_XLA (mm , lower_precision_fp )
15+ KERNEL_XLA (baddbmm , lower_precision_fp )
16+ KERNEL_XLA (addmm , lower_precision_fp )
17+ KERNEL_XLA (addbmm , lower_precision_fp )
18+ KERNEL_XLA (linear , lower_precision_fp )
19+ KERNEL_XLA (matmul , lower_precision_fp )
20+ KERNEL_XLA (conv_tbc , lower_precision_fp )
21+ KERNEL_XLA (conv_transpose1d , lower_precision_fp )
22+ KERNEL_XLA2 (conv_transpose2d , input , lower_precision_fp )
23+ KERNEL_XLA2 (conv_transpose3d , input , lower_precision_fp )
24+ KERNEL_XLA (prelu , lower_precision_fp )
25+ KERNEL_XLA (relu , lower_precision_fp )
26+ KERNEL_XLA (max_pool2d , lower_precision_fp )
27+ KERNEL_XLA (einsum , lower_precision_fp )
28+ // Disable `scaled_dot_product_attention` for now since it causes
29+ // undefined symbol with official torch whl .
30+ // KERNEL_XLA (scaled_dot_product_attention , lower_precision_fp )
31+
32+ // fp32 cast policy
33+ // Commented out ops are included in the AutoCastCPU Policy ,
34+ // but not lowered . Enable if op is lowered .
35+ KERNEL_XLA (batch_norm , fp32 )
36+ KERNEL_XLA (_softmax , fp32 )
37+ KERNEL_XLA2 (softmax , int , fp32 )
38+ KERNEL_XLA2 (softmax , Dimname , fp32 )
39+ KERNEL_XLA2 (log_softmax , int , fp32 )
40+ KERNEL_XLA2 (log_softmax , Dimname , fp32 )
41+ KERNEL_XLA (binary_cross_entropy , fp32 )
42+ // KERNEL_XLA (grid_sampler , fp32 )
43+ // KERNEL_XLA (polar , fp32 )
44+ KERNEL_XLA2 (pow , Tensor_Scalar , fp32 )
45+ KERNEL_XLA (prod , fp32 )
46+ KERNEL_XLA2 (prod , dim_int , fp32 )
47+ KERNEL_XLA2 (prod , dim_Dimname , fp32 )
48+ // KERNEL_XLA (quantile , fp32 )
49+ // KERNEL_XLA2 (quantile , scalar , fp32 )
50+ // KERNEL_XLA (nanquantile , fp32 )
51+ // KERNEL_XLA2 (nanquantile , scalar , fp32 )
52+ // KERNEL_XLA (stft , fp32 )
53+ // KERNEL_XLA2 (stft , center , fp32 )
54+ KERNEL_XLA (cdist , fp32 )
55+ // KERNEL_XLA (grid_sampler_2d , fp32 )
56+ // KERNEL_XLA (grid_sampler_3d , fp32 )
57+ KERNEL_XLA (trace , fp32 )
58+ // KERNEL_XLA (view_as_complex , fp32 )
59+ KERNEL_XLA (cholesky , fp32 )
60+ KERNEL_XLA (cholesky_inverse , fp32 )
61+ KERNEL_XLA (cholesky_solve , fp32 )
62+ KERNEL_XLA (inverse , fp32 )
63+ // KERNEL_XLA (lu_solve , fp32 )
64+ // KERNEL_XLA (orgqr , fp32 )
65+ // KERNEL_XLA (ormqr , fp32 )
66+ // KERNEL_XLA (pinverse , fp32 )
67+ KERNEL_XLA (reflection_pad1d , fp32 )
68+ KERNEL_XLA (reflection_pad2d , fp32 )
69+ KERNEL_XLA (replication_pad1d , fp32 )
70+ KERNEL_XLA (replication_pad2d , fp32 )
71+ KERNEL_XLA (replication_pad3d , fp32 )
72+ KERNEL_XLA (mse_loss , fp32 )
73+ KERNEL_XLA (cosine_embedding_loss , fp32 )
74+ KERNEL_XLA (nll_loss , fp32 )
75+ KERNEL_XLA (nll_loss2d , fp32 )
76+ KERNEL_XLA (hinge_embedding_loss , fp32 )
77+ // KERNEL_XLA (poisson_nll_loss , fp32 )
78+ KERNEL_XLA (smooth_l1_loss , fp32 )
79+ KERNEL_XLA (cross_entropy_loss , fp32 )
80+ KERNEL_XLA (l1_loss , fp32 )
81+ // KERNEL_XLA (huber_loss , fp32 )
82+ KERNEL_XLA (margin_ranking_loss , fp32 )
83+ KERNEL_XLA (soft_margin_loss , fp32 )
84+ KERNEL_XLA (triplet_margin_loss , fp32 )
85+ KERNEL_XLA (multi_margin_loss , fp32 )
86+ KERNEL_XLA2 (ctc_loss , IntList , fp32 )
87+ KERNEL_XLA2 (ctc_loss , Tensor , fp32 )
88+ KERNEL_XLA (kl_div , fp32 )
89+ KERNEL_XLA (multilabel_margin_loss , fp32 )
90+ KERNEL_XLA (binary_cross_entropy_with_logits , fp32 )
91+ // KERNEL_XLA (fft_fft , fp32 )
92+ // KERNEL_XLA (fft_ifft , fp32 )
93+ // KERNEL_XLA (fft_fft2 , fp32 )
94+ // KERNEL_XLA (fft_ifft2 , fp32 )
95+ // KERNEL_XLA (fft_fftn , fp32 )
96+ // KERNEL_XLA (fft_ifftn , fp32 )
97+ // KERNEL_XLA (fft_rfft , fp32 )
98+ // KERNEL_XLA (fft_irfft , fp32 )
99+ // KERNEL_XLA (fft_rfft2 , fp32 )
100+ // KERNEL_XLA (fft_irfft2 , fp32 )
101+ // KERNEL_XLA (fft_rfftn , fp32 )
102+ // KERNEL_XLA (fft_irfftn , fp32 )
103+ // KERNEL_XLA (fft_hfft , fp32 )
104+ // KERNEL_XLA (fft_ihfft , fp32 )
105+ // KERNEL_XLA (linalg_cond , fp32 )
106+ // KERNEL_XLA2 (linalg_cond , p_str , fp32 )
107+ // KERNEL_XLA (linalg_matrix_rank , fp32 )
108+ // KERNEL_XLA2 (linalg_matrix_rank , tol_tensor , fp32 )
109+ // KERNEL_XLA2 (linalg_matrix_rank , atol_rtol_tensor , fp32 )
110+ // KERNEL_XLA2 (linalg_matrix_rank , atol_rtol_float , fp32 )
111+ // KERNEL_XLA (linalg_solve , fp32 )
112+ // KERNEL_XLA (linalg_cholesky , fp32 )
113+ // KERNEL_XLA (linalg_svdvals , fp32 )
114+ // KERNEL_XLA (linalg_eigvals , fp32 )
115+ // KERNEL_XLA (linalg_eigvalsh , fp32 )
116+ // KERNEL_XLA (linalg_inv , fp32 )
117+ // KERNEL_XLA (linalg_householder_product , fp32 )
118+ // KERNEL_XLA (linalg_tensorinv , fp32 )
119+ // KERNEL_XLA (linalg_tensorsolve , fp32 )
120+ // KERNEL_XLA (fake_quantize_per_tensor_affine , fp32 )
121+ // KERNEL_XLA (geqrf , fp32 )
122+ // KERNEL_XLA (_lu_with_info , fp32 )
123+ KERNEL_XLA (qr , fp32 )
124+ KERNEL_XLA (svd , fp32 )
125+ KERNEL_XLA (triangular_solve , fp32 )
126+ KERNEL_XLA (multilabel_margin_loss_forward , fp32 )
127+ // KERNEL_XLA (linalg_qr , fp32 )
128+ // KERNEL_XLA (linalg_cholesky_ex , fp32 )
129+ KERNEL_XLA (linalg_svd , fp32 )
130+ // KERNEL_XLA (linalg_eig , fp32 )
131+ // KERNEL_XLA (linalg_eigh , fp32 )
132+ // KERNEL_XLA (linalg_lstsq , fp32 )
133+ KERNEL_XLA (linalg_inv_ex , fp32 )
134+
135+ // promote
136+ KERNEL_XLA (stack , promote )
137+ KERNEL_XLA (cat , promote )
138+ KERNEL_XLA (index_copy , promote )
139+ KERNEL_XLA2 (index_copy , dimname , promote )
0 commit comments