import torch
import torch._C
from torch.utils import _pytree as pytree


def call_with_next_key(op, args, kwargs):
    # Redispatch `op` with the original arguments.
    return op(*args, **kwargs)


# Module-level default; the runtime autocast dtype queried inside
# `lower_precision_fp` takes precedence.
target_precision = torch.bfloat16


def lower_precision_fp(op):
    # Wrap `op` so that, under the AutocastPrivateUse1 key, every
    # floating-point tensor input is cast to the autocast dtype first.
    def inner(*args, **kwargs):
        target_precision = torch.get_autocast_dtype('privateuseone')
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        # Exclude the autocast key while running the real op so this
        # wrapper does not re-enter itself on redispatch.
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
            args, kwargs = pytree.tree_map_only(
                is_float_tensor,
                lambda x: x.to(target_precision),
                (args, kwargs))
            return op(*args, **kwargs)
    return inner
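

# Sketch of the `fp32` cast policy referenced in the torch_xla listing
# quoted at the bottom of this file: widen reduced-precision inputs back
# to float32 before running the op. This helper (`force_fp32`) is an
# illustrative assumption and is not registered below.
def force_fp32(op):
    def inner(*args, **kwargs):
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            # Only widen fp16/bf16 tensors; fp32/fp64 inputs are untouched.
            is_low_fp_tensor = lambda a: (isinstance(a, torch.Tensor)
                                          and a.dtype in (torch.float16, torch.bfloat16))
            args, kwargs = pytree.tree_map_only(
                is_low_fp_tensor,
                lambda x: x.to(torch.float32),
                (args, kwargs))
            return op(*args, **kwargs)
    return inner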


# Register our kernels into a fragment of the existing aten library, and
# install a fallthrough for every other op so it ignores the
# AutocastPrivateUse1 key entirely.
lib = torch.library.Library('aten', 'FRAGMENT')
my_lib = torch.library.Library('_', 'IMPL', 'AutocastPrivateUse1')
my_lib.fallback(torch.library.fallthrough_kernel)

# Ops that follow the `lower_precision_fp` cast policy, mirroring the
# torch_xla AutocastXLA registrations quoted below.
for op in [torch.ops.aten.conv1d.default,
           torch.ops.aten.conv1d.padding,
           torch.ops.aten.conv2d.default,
           torch.ops.aten.conv2d.padding,
           torch.ops.aten.conv3d.default,
           torch.ops.aten.bmm.default,
           torch.ops.aten.mm.default,
           torch.ops.aten.baddbmm.default,
           torch.ops.aten.addmm.default,
           torch.ops.aten.addbmm.default,
           torch.ops.aten.linear.default,
           torch.ops.aten.matmul.default,
           torch.ops.aten.conv_tbc.default,
           torch.ops.aten.conv_transpose1d.default,
           torch.ops.aten.conv_transpose2d.input,
           torch.ops.aten.conv_transpose3d.input,
           torch.ops.aten.prelu.default,
           torch.ops.aten.relu.default,
           torch.ops.aten.max_pool2d.default,
           torch.ops.aten.einsum.default,
           ]:
    lib.impl(op.name(), lower_precision_fp(op), "AutocastPrivateUse1", with_keyset=False)
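
# A minimal usage sketch (kept as comments, since it assumes a PrivateUse1
# backend is already compiled in and registered; the backend name
# 'my_device' is hypothetical):
#
#   torch.utils.rename_privateuse1_backend('my_device')
#   x = torch.randn(8, 16, device='my_device')
#   w = torch.randn(16, 4, device='my_device')
#   with torch.autocast(device_type='my_device', dtype=torch.bfloat16):
#       y = torch.mm(x, w)  # routed through lower_precision_fp, so both
#                           # inputs are cast to bfloat16 before aten::mm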

# https://github.com/pytorch/xla/blob/20899c7258680a36cd3bec1c820e8a52c16a4bbf/torch_xla/csrc/autocast_mode.cpp#L29
# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0,  // Cast all inputs to lower_precision_fp before
#                            // running the op. Currently, lower_precision_fp is
#                            // fp16 for AutocastCUDA, and is defined by user
#                            // (default bf16) for AutocastCPU or other device.
#   fp32,  // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype,  // Treats functions (like softmax) that
#                        //   1. we'd like to run in fp32 and
#                        //   2. have a std::optional<ScalarType> arg that controls
#                        //      the output type.
#                        // fp32_set_opt_dtype wrappers' policy is: if the output
#                        // type is already set, don't touch it, otherwise, set
#                        // it to at::kFloat.
#   fp32_append_dtype,  // Treats functions (like norm) that
#                       //   1. we'd like to run in fp32 and
#                       //   2. have some overloads that accept an output type and
#                       //      other overloads that don't.
#                       // fp32_append_dtype wrappers wrap the overloads that don't
#                       // have an output dtype.
#                       // The wrapper policy is: append at::kFloat to the args,
#                       // and redispatch to the type-aware overload.
#   promote,  // Run in the widest dtype among several args.
# };
# TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
#   // lower_precision_fp cast policy
#   KERNEL_XLA(conv1d, lower_precision_fp)
#   KERNEL_XLA2(conv1d, padding, lower_precision_fp)
#   KERNEL_XLA(conv2d, lower_precision_fp)
#   KERNEL_XLA2(conv2d, padding, lower_precision_fp)
#   KERNEL_XLA(conv3d, lower_precision_fp)
#   KERNEL_XLA(bmm, lower_precision_fp)
#   KERNEL_XLA(mm, lower_precision_fp)
#   KERNEL_XLA(baddbmm, lower_precision_fp)
#   KERNEL_XLA(addmm, lower_precision_fp)
#   KERNEL_XLA(addbmm, lower_precision_fp)
#   KERNEL_XLA(linear, lower_precision_fp)
#   KERNEL_XLA(matmul, lower_precision_fp)
#   KERNEL_XLA(conv_tbc, lower_precision_fp)
#   KERNEL_XLA(conv_transpose1d, lower_precision_fp)
#   KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp)
#   KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp)
#   KERNEL_XLA(prelu, lower_precision_fp)
#   KERNEL_XLA(relu, lower_precision_fp)
#   KERNEL_XLA(max_pool2d, lower_precision_fp)
#   KERNEL_XLA(einsum, lower_precision_fp)
#   // Disable `scaled_dot_product_attention` for now since it causes
#   // undefined symbol with official torch whl.
#   // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)

#   // fp32 cast policy
#   // Commented out ops are included in the AutoCastCPU Policy,
#   // but not lowered. Enable if op is lowered.
#   KERNEL_XLA(batch_norm, fp32)
#   KERNEL_XLA(_softmax, fp32)
#   KERNEL_XLA2(softmax, int, fp32)
#   KERNEL_XLA2(softmax, Dimname, fp32)
#   KERNEL_XLA2(log_softmax, int, fp32)
#   KERNEL_XLA2(log_softmax, Dimname, fp32)
#   KERNEL_XLA(binary_cross_entropy, fp32)
#   // KERNEL_XLA(grid_sampler, fp32)
#   // KERNEL_XLA(polar, fp32)
#   KERNEL_XLA2(pow, Tensor_Scalar, fp32)
#   KERNEL_XLA(prod, fp32)
#   KERNEL_XLA2(prod, dim_int, fp32)
#   KERNEL_XLA2(prod, dim_Dimname, fp32)
#   // KERNEL_XLA(quantile, fp32)
#   // KERNEL_XLA2(quantile, scalar, fp32)
#   // KERNEL_XLA(nanquantile, fp32)
#   // KERNEL_XLA2(nanquantile, scalar, fp32)
#   // KERNEL_XLA(stft, fp32)
#   // KERNEL_XLA2(stft, center, fp32)
#   KERNEL_XLA(cdist, fp32)
#   // KERNEL_XLA(grid_sampler_2d, fp32)
#   // KERNEL_XLA(grid_sampler_3d, fp32)
#   KERNEL_XLA(trace, fp32)
#   // KERNEL_XLA(view_as_complex, fp32)
#   KERNEL_XLA(cholesky, fp32)
#   KERNEL_XLA(cholesky_inverse, fp32)
#   KERNEL_XLA(cholesky_solve, fp32)
#   KERNEL_XLA(inverse, fp32)
#   // KERNEL_XLA(lu_solve, fp32)
#   // KERNEL_XLA(orgqr, fp32)
#   // KERNEL_XLA(ormqr, fp32)
#   // KERNEL_XLA(pinverse, fp32)
#   KERNEL_XLA(reflection_pad1d, fp32)
#   KERNEL_XLA(reflection_pad2d, fp32)
#   KERNEL_XLA(replication_pad1d, fp32)
#   KERNEL_XLA(replication_pad2d, fp32)
#   KERNEL_XLA(replication_pad3d, fp32)
#   KERNEL_XLA(mse_loss, fp32)
#   KERNEL_XLA(cosine_embedding_loss, fp32)
#   KERNEL_XLA(nll_loss, fp32)
#   KERNEL_XLA(nll_loss2d, fp32)
#   KERNEL_XLA(hinge_embedding_loss, fp32)
#   // KERNEL_XLA(poisson_nll_loss, fp32)
#   KERNEL_XLA(smooth_l1_loss, fp32)
#   KERNEL_XLA(cross_entropy_loss, fp32)
#   KERNEL_XLA(l1_loss, fp32)
#   // KERNEL_XLA(huber_loss, fp32)
#   KERNEL_XLA(margin_ranking_loss, fp32)
#   KERNEL_XLA(soft_margin_loss, fp32)
#   KERNEL_XLA(triplet_margin_loss, fp32)
#   KERNEL_XLA(multi_margin_loss, fp32)
#   KERNEL_XLA2(ctc_loss, IntList, fp32)
#   KERNEL_XLA2(ctc_loss, Tensor, fp32)
#   KERNEL_XLA(kl_div, fp32)
#   KERNEL_XLA(multilabel_margin_loss, fp32)
#   KERNEL_XLA(binary_cross_entropy_with_logits, fp32)
#   // KERNEL_XLA(fft_fft, fp32)
#   // KERNEL_XLA(fft_ifft, fp32)
#   // KERNEL_XLA(fft_fft2, fp32)
#   // KERNEL_XLA(fft_ifft2, fp32)
#   // KERNEL_XLA(fft_fftn, fp32)
#   // KERNEL_XLA(fft_ifftn, fp32)
#   // KERNEL_XLA(fft_rfft, fp32)
#   // KERNEL_XLA(fft_irfft, fp32)
#   // KERNEL_XLA(fft_rfft2, fp32)
#   // KERNEL_XLA(fft_irfft2, fp32)
#   // KERNEL_XLA(fft_rfftn, fp32)
#   // KERNEL_XLA(fft_irfftn, fp32)
#   // KERNEL_XLA(fft_hfft, fp32)
#   // KERNEL_XLA(fft_ihfft, fp32)
#   // KERNEL_XLA(linalg_cond, fp32)
#   // KERNEL_XLA2(linalg_cond, p_str, fp32)
#   // KERNEL_XLA(linalg_matrix_rank, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32)
#   // KERNEL_XLA(linalg_solve, fp32)
#   // KERNEL_XLA(linalg_cholesky, fp32)
#   // KERNEL_XLA(linalg_svdvals, fp32)
#   // KERNEL_XLA(linalg_eigvals, fp32)
#   // KERNEL_XLA(linalg_eigvalsh, fp32)
#   // KERNEL_XLA(linalg_inv, fp32)
#   // KERNEL_XLA(linalg_householder_product, fp32)
#   // KERNEL_XLA(linalg_tensorinv, fp32)
#   // KERNEL_XLA(linalg_tensorsolve, fp32)
#   // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
#   // KERNEL_XLA(geqrf, fp32)
#   // KERNEL_XLA(_lu_with_info, fp32)
#   KERNEL_XLA(qr, fp32)
#   KERNEL_XLA(svd, fp32)
#   KERNEL_XLA(triangular_solve, fp32)
#   KERNEL_XLA(multilabel_margin_loss_forward, fp32)
#   // KERNEL_XLA(linalg_qr, fp32)
#   // KERNEL_XLA(linalg_cholesky_ex, fp32)
#   KERNEL_XLA(linalg_svd, fp32)
#   // KERNEL_XLA(linalg_eig, fp32)
#   // KERNEL_XLA(linalg_eigh, fp32)
#   // KERNEL_XLA(linalg_lstsq, fp32)
#   KERNEL_XLA(linalg_inv_ex, fp32)

#   // promote
#   KERNEL_XLA(stack, promote)
#   KERNEL_XLA(cat, promote)
#   KERNEL_XLA(index_copy, promote)
#   KERNEL_XLA2(index_copy, dimname, promote)
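

# The quoted policies beyond lower_precision_fp are not implemented above.
# Minimal sketches follow, assuming the same exclude-guard pattern as
# `lower_precision_fp`; the helper names `fp32_set_opt_dtype` and
# `promote_widest` are hypothetical and are not registered in this file.

def fp32_set_opt_dtype(op):
    # fp32_set_opt_dtype: for ops like softmax whose optional `dtype`
    # argument controls the output type, respect a caller-supplied dtype
    # and otherwise run the op in float32.
    def inner(*args, dtype=None, **kwargs):
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            return op(*args, dtype=dtype if dtype is not None else torch.float32, **kwargs)
    return inner


def promote_widest(op):
    # promote: cast every floating-point tensor input to the widest
    # floating-point dtype that appears among the inputs, then dispatch.
    def inner(*args, **kwargs):
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
            float_dtypes = [a.dtype for a in pytree.tree_leaves((args, kwargs))
                            if is_float_tensor(a)]
            if not float_dtypes:
                return op(*args, **kwargs)
            widest = float_dtypes[0]
            for d in float_dtypes[1:]:
                widest = torch.promote_types(widest, d)  # widest of each pair
            args, kwargs = pytree.tree_map_only(
                is_float_tensor,
                lambda x: x.to(widest),
                (args, kwargs))
            return op(*args, **kwargs)
    return inner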