import contextlib
import enum

import torch
from torch.utils import _pytree as pytree

# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
#                           // running the op. Currently, lower_precision_fp is
#                           // fp16 for AutocastCUDA, and is defined by user
#                           // (default bf16) for AutocastCPU or other device.
#   fp32, // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype, // Treats functions (like softmax) that
#                       //  1. we'd like to run in fp32 and
#                       //  2. have a std::optional<ScalarType> arg that controls
#                       //  the output type.
#                       // fp32_set_opt_dtype wrappers' policy is: if the output
#                       // type is already set, don't touch it, otherwise, set
#                       // it to at::kFloat.
#   fp32_append_dtype, // Treats functions (like norm) that
#                      //  1. we'd like to run in fp32 and
#                      //  2. have some overloads that accept an output type and
#                      //  other overloads that don't.
#                      // fp32_append_dtype wrappers wrap the overloads that don't
#                      // have an output dtype.
#                      // The wrapper policy is: append at::kFloat to the args,
#                      // and redispatch to the type-aware overload.
#   promote, // Run in the widest dtype among several args.
# };
class CastPolicy(enum.Enum):
  LOWER_PRECISION_FP = 0
  FP32 = 1
  FP32_SET_OPT_DTYPE = 2
  FP32_APPEND_DTYPE = 3
  PROMOTE = 4


def execute_policy(policy, args, kwargs, target_lower_fp):
  """Cast floating-point tensors in (args, kwargs) according to `policy`."""

  def is_float(a):
    return isinstance(a, torch.Tensor) and a.is_floating_point()

  match policy:
    case CastPolicy.LOWER_PRECISION_FP:
      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
                                  (args, kwargs))
    case CastPolicy.FP32:
      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
                                  (args, kwargs))
    case CastPolicy.PROMOTE:
      # Promote to the widest floating-point dtype found among the tensor args.
      dtypes = {a.dtype for a in args if is_float(a)}
      widest = max(dtypes, key=lambda dtype: dtype.itemsize)
      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
                                  (args, kwargs))
    case _:
      raise AssertionError(f'Policy {policy} not implemented')
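

# A minimal usage sketch of execute_policy, kept as a comment so importing this
# module stays side-effect free; the tensor names `a` and `b` are illustrative
# only, not part of the module.
#
#   a = torch.randn(2, 2, dtype=torch.float32)
#   b = torch.randn(2, 2, dtype=torch.float32)
#   cast_args, cast_kwargs = execute_policy(
#       CastPolicy.LOWER_PRECISION_FP, (a, b), {}, torch.bfloat16)
#   # Both tensors are now bfloat16; non-floating inputs are left untouched.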


@contextlib.contextmanager
def autocast(device, dtype=torch.bfloat16, env=None):
  """Enable autocast on `env` for the duration of the context."""
  del device
  if env is None:
    import torchax
    env = torchax.default_env()
  old = env.autocast_dtype
  env.autocast_dtype = dtype
  try:
    yield
  finally:
    # Restore the previous autocast dtype even if the body raises.
    env.autocast_dtype = old
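

# A minimal usage sketch, assuming a torchax default environment is available;
# `model` and `inputs` are illustrative names, and the device argument is
# currently ignored by this context manager.
#
#   with autocast('jax', dtype=torch.bfloat16):
#     output = model(inputs)  # ops listed in autocast_policy are cast per policy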


# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
autocast_policy = {
  torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
  torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,

  # fp32 cast policy
  torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
  torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
  torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
  torch.ops.aten.polar.default: CastPolicy.FP32,
  torch.ops.aten.prod.default: CastPolicy.FP32,
  torch.ops.aten.prod.dim_int: CastPolicy.FP32,
  torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
  torch.ops.aten.quantile.default: CastPolicy.FP32,
  torch.ops.aten.quantile.scalar: CastPolicy.FP32,
  torch.ops.aten.nanquantile.default: CastPolicy.FP32,
  torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
  torch.ops.aten.stft.default: CastPolicy.FP32,
  torch.ops.aten.stft.center: CastPolicy.FP32,
  torch.ops.aten.cdist.default: CastPolicy.FP32,
  torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
  torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
  torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
  torch.ops.aten.trace.default: CastPolicy.FP32,
  torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
  torch.ops.aten.cholesky.default: CastPolicy.FP32,
  torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
  torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
  torch.ops.aten.inverse.default: CastPolicy.FP32,
  torch.ops.aten.lu_solve.default: CastPolicy.FP32,
  torch.ops.aten.orgqr.default: CastPolicy.FP32,
  torch.ops.aten.ormqr.default: CastPolicy.FP32,
  torch.ops.aten.pinverse.default: CastPolicy.FP32,
  torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
  torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
  torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
  torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
  torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
  torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
  torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
  torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
  torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
  torch.ops.aten.mse_loss.default: CastPolicy.FP32,
  torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
  torch.ops.aten.nll_loss.default: CastPolicy.FP32,
  torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
  torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
  torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
  torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
  torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
  torch.ops.aten.l1_loss.default: CastPolicy.FP32,
  torch.ops.aten.huber_loss.default: CastPolicy.FP32,
  torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
  torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
  torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
  torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
  torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
  torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
  torch.ops.aten.kl_div.default: CastPolicy.FP32,
  torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
  torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
  torch.ops.aten.fft_fft.default: CastPolicy.FP32,
  torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
  torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
  torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
  torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
  torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
  torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
  torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
  torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
  torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
  torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
  torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
  torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
  torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
  torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
  torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
  torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
  torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
  torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
  torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
  torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
  torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
  torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
  torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
  torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
  torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
  torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
  torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
  torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
  torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
  torch.ops.aten.geqrf.default: CastPolicy.FP32,
  torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
  torch.ops.aten.qr.default: CastPolicy.FP32,
  torch.ops.aten.svd.default: CastPolicy.FP32,
  torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
  torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
  torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
  torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
  torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
  torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
  torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
  torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
  torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
  torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
  torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
  torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,

  # promote
  torch.ops.aten.stack.default: CastPolicy.PROMOTE,
  torch.ops.aten.cat.default: CastPolicy.PROMOTE,
  torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
  torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
}
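

# Illustrative only: a sketch of how a dispatching layer might consult the
# table above, assuming an `env` object that exposes `autocast_dtype` (as
# toggled by the `autocast` context manager) and that the attribute is None
# when autocasting is off. `maybe_autocast` is a hypothetical helper, not
# torchax API.
def maybe_autocast(op, args, kwargs, env):
  policy = autocast_policy.get(op)
  if policy is None or env.autocast_dtype is None:
    # Op not covered by the table, or autocast disabled: leave inputs as-is.
    return args, kwargs
  return execute_policy(policy, args, kwargs, env.autocast_dtype)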