
Commit 80896f3

checkpoint so far
1 parent d1c8fff commit 80896f3

4 files changed: +150 -6 lines changed

torchax/torchax/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -78,7 +78,7 @@ def disable_temporarily():
   enable_globally()
 
 
-torch.utils.rename_privateuse1_backend('jax')
+#torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
 torch.utils.generate_methods_for_privateuse1_backend(
     for_tensor=True,
@@ -89,7 +89,7 @@ def disable_temporarily():
 import jax
 import torchax.device_module
 
-torch._register_device_module('jax', torchax.device_module)
+torch._register_device_module('privateuseone', torchax.device_module)
 
 
 def enable_accuracy_mode():
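
Taken together, the two hunks above move torchax off the renamed 'jax' backend and onto the raw privateuse1 name. A condensed sketch of the registration flow after this commit (argument lists trimmed; the unsupported_dtype handling is elided):

import torch
import torchax.device_module

# Old behaviour (now commented out): alias the privateuse1 backend as 'jax'.
# torch.utils.rename_privateuse1_backend('jax')

# Convenience methods are still generated for the privateuse1 backend.
torch.utils.generate_methods_for_privateuse1_backend(for_tensor=True)

# New behaviour: register the device module under the raw 'privateuseone'
# name rather than under the 'jax' alias.
torch._register_device_module('privateuseone', torchax.device_module)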

torchax/torchax/device_module.py

Lines changed: 5 additions & 0 deletions

@@ -24,3 +24,8 @@ def is_available():
 
 def current_device():
   return 0
+
+
+import torch
+def get_amp_supported_dtype():
+  return [torch.float16, torch.bfloat16]
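
This hook matters because torch.amp, when autocast is entered for a custom (privateuse1) backend, asks the registered device module which low-precision dtypes it supports. A usage sketch, assuming the registration in __init__.py above runs at import time:

import torch
import torchax  # module-level code in __init__.py registers the device module

# The device module is now reachable under the backend name, and torch.amp
# consults it to validate the requested autocast dtype for this backend.
print(torch.privateuseone.get_amp_supported_dtype())
# expected: [torch.float16, torch.bfloat16]

with torch.autocast(device_type='privateuseone', dtype=torch.bfloat16):
    pass  # autocast-eligible ops dispatched to this backend may run in bfloat16

Whether any op is actually recast inside such a region depends on autocast kernels being registered for the backend's dispatch key, which is what the reference listing in the next file is about.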
Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
+# https://github.com/pytorch/xla/blob/20899c7258680a36cd3bec1c820e8a52c16a4bbf/torch_xla/csrc/autocast_mode.cpp#L29
+
+
+
+TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
+  // lower_precision_fp cast policy
+  KERNEL_XLA(conv1d, lower_precision_fp)
+  KERNEL_XLA2(conv1d, padding, lower_precision_fp)
+  KERNEL_XLA(conv2d, lower_precision_fp)
+  KERNEL_XLA2(conv2d, padding, lower_precision_fp)
+  KERNEL_XLA(conv3d, lower_precision_fp)
+  KERNEL_XLA2(conv3d, padding, lower_precision_fp)
+  KERNEL_XLA(bmm, lower_precision_fp)
+  KERNEL_XLA(mm, lower_precision_fp)
+  KERNEL_XLA(baddbmm, lower_precision_fp)
+  KERNEL_XLA(addmm, lower_precision_fp)
+  KERNEL_XLA(addbmm, lower_precision_fp)
+  KERNEL_XLA(linear, lower_precision_fp)
+  KERNEL_XLA(matmul, lower_precision_fp)
+  KERNEL_XLA(conv_tbc, lower_precision_fp)
+  KERNEL_XLA(conv_transpose1d, lower_precision_fp)
+  KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp)
+  KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp)
+  KERNEL_XLA(prelu, lower_precision_fp)
+  KERNEL_XLA(relu, lower_precision_fp)
+  KERNEL_XLA(max_pool2d, lower_precision_fp)
+  KERNEL_XLA(einsum, lower_precision_fp)
+  // Disable `scaled_dot_product_attention` for now since it causes
+  // undefined symbol with official torch whl.
+  // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)
+
+  // fp32 cast policy
+  // Commented out ops are included in the AutoCastCPU Policy,
+  // but not lowered. Enable if op is lowered.
+  KERNEL_XLA(batch_norm, fp32)
+  KERNEL_XLA(_softmax, fp32)
+  KERNEL_XLA2(softmax, int, fp32)
+  KERNEL_XLA2(softmax, Dimname, fp32)
+  KERNEL_XLA2(log_softmax, int, fp32)
+  KERNEL_XLA2(log_softmax, Dimname, fp32)
+  KERNEL_XLA(binary_cross_entropy, fp32)
+  // KERNEL_XLA(grid_sampler, fp32)
+  // KERNEL_XLA(polar, fp32)
+  KERNEL_XLA2(pow, Tensor_Scalar, fp32)
+  KERNEL_XLA(prod, fp32)
+  KERNEL_XLA2(prod, dim_int, fp32)
+  KERNEL_XLA2(prod, dim_Dimname, fp32)
+  // KERNEL_XLA(quantile, fp32)
+  // KERNEL_XLA2(quantile, scalar, fp32)
+  // KERNEL_XLA(nanquantile, fp32)
+  // KERNEL_XLA2(nanquantile, scalar, fp32)
+  // KERNEL_XLA(stft, fp32)
+  // KERNEL_XLA2(stft, center, fp32)
+  KERNEL_XLA(cdist, fp32)
+  // KERNEL_XLA(grid_sampler_2d, fp32)
+  // KERNEL_XLA(grid_sampler_3d, fp32)
+  KERNEL_XLA(trace, fp32)
+  // KERNEL_XLA(view_as_complex, fp32)
+  KERNEL_XLA(cholesky, fp32)
+  KERNEL_XLA(cholesky_inverse, fp32)
+  KERNEL_XLA(cholesky_solve, fp32)
+  KERNEL_XLA(inverse, fp32)
+  // KERNEL_XLA(lu_solve, fp32)
+  // KERNEL_XLA(orgqr, fp32)
+  // KERNEL_XLA(ormqr, fp32)
+  // KERNEL_XLA(pinverse, fp32)
+  KERNEL_XLA(reflection_pad1d, fp32)
+  KERNEL_XLA(reflection_pad2d, fp32)
+  KERNEL_XLA(replication_pad1d, fp32)
+  KERNEL_XLA(replication_pad2d, fp32)
+  KERNEL_XLA(replication_pad3d, fp32)
+  KERNEL_XLA(mse_loss, fp32)
+  KERNEL_XLA(cosine_embedding_loss, fp32)
+  KERNEL_XLA(nll_loss, fp32)
+  KERNEL_XLA(nll_loss2d, fp32)
+  KERNEL_XLA(hinge_embedding_loss, fp32)
+  // KERNEL_XLA(poisson_nll_loss, fp32)
+  KERNEL_XLA(smooth_l1_loss, fp32)
+  KERNEL_XLA(cross_entropy_loss, fp32)
+  KERNEL_XLA(l1_loss, fp32)
+  // KERNEL_XLA(huber_loss, fp32)
+  KERNEL_XLA(margin_ranking_loss, fp32)
+  KERNEL_XLA(soft_margin_loss, fp32)
+  KERNEL_XLA(triplet_margin_loss, fp32)
+  KERNEL_XLA(multi_margin_loss, fp32)
+  KERNEL_XLA2(ctc_loss, IntList, fp32)
+  KERNEL_XLA2(ctc_loss, Tensor, fp32)
+  KERNEL_XLA(kl_div, fp32)
+  KERNEL_XLA(multilabel_margin_loss, fp32)
+  KERNEL_XLA(binary_cross_entropy_with_logits, fp32)
+  // KERNEL_XLA(fft_fft, fp32)
+  // KERNEL_XLA(fft_ifft, fp32)
+  // KERNEL_XLA(fft_fft2, fp32)
+  // KERNEL_XLA(fft_ifft2, fp32)
+  // KERNEL_XLA(fft_fftn, fp32)
+  // KERNEL_XLA(fft_ifftn, fp32)
+  // KERNEL_XLA(fft_rfft, fp32)
+  // KERNEL_XLA(fft_irfft, fp32)
+  // KERNEL_XLA(fft_rfft2, fp32)
+  // KERNEL_XLA(fft_irfft2, fp32)
+  // KERNEL_XLA(fft_rfftn, fp32)
+  // KERNEL_XLA(fft_irfftn, fp32)
+  // KERNEL_XLA(fft_hfft, fp32)
+  // KERNEL_XLA(fft_ihfft, fp32)
+  // KERNEL_XLA(linalg_cond, fp32)
+  // KERNEL_XLA2(linalg_cond, p_str, fp32)
+  // KERNEL_XLA(linalg_matrix_rank, fp32)
+  // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32)
+  // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32)
+  // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32)
+  // KERNEL_XLA(linalg_solve, fp32)
+  // KERNEL_XLA(linalg_cholesky, fp32)
+  // KERNEL_XLA(linalg_svdvals, fp32)
+  // KERNEL_XLA(linalg_eigvals, fp32)
+  // KERNEL_XLA(linalg_eigvalsh, fp32)
+  // KERNEL_XLA(linalg_inv, fp32)
+  // KERNEL_XLA(linalg_householder_product, fp32)
+  // KERNEL_XLA(linalg_tensorinv, fp32)
+  // KERNEL_XLA(linalg_tensorsolve, fp32)
+  // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
+  // KERNEL_XLA(geqrf, fp32)
+  // KERNEL_XLA(_lu_with_info, fp32)
+  KERNEL_XLA(qr, fp32)
+  KERNEL_XLA(svd, fp32)
+  KERNEL_XLA(triangular_solve, fp32)
+  KERNEL_XLA(multilabel_margin_loss_forward, fp32)
+  // KERNEL_XLA(linalg_qr, fp32)
+  // KERNEL_XLA(linalg_cholesky_ex, fp32)
+  KERNEL_XLA(linalg_svd, fp32)
+  // KERNEL_XLA(linalg_eig, fp32)
+  // KERNEL_XLA(linalg_eigh, fp32)
+  // KERNEL_XLA(linalg_lstsq, fp32)
+  KERNEL_XLA(linalg_inv_ex, fp32)
+
+  // promote
+  KERNEL_XLA(stack, promote)
+  KERNEL_XLA(cat, promote)
+  KERNEL_XLA(index_copy, promote)
+  KERNEL_XLA2(index_copy, dimname, promote)
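
The listing above (copied from pytorch/xla's autocast_mode.cpp for reference) sorts ops into three cast policies: lower_precision_fp ops are recast to the autocast dtype, fp32 ops are forced back to float32, and promote ops run in the widest dtype among their inputs. The same policies already exist in the in-tree CPU autocast, so their effect can be shown there; this is a CPU-side illustration, not torchax code:

import torch
import torch.nn.functional as F

a = torch.randn(8, 8)  # float32 input

with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    out_lp = a @ a                          # lower_precision_fp: matmul runs in bfloat16
    out_fp32 = F.mse_loss(out_lp, out_lp)   # fp32: inputs are cast up, loss stays float32
    out_promote = torch.cat([a, out_lp])    # promote: widest input dtype wins -> float32

print(out_lp.dtype, out_fp32.dtype, out_promote.dtype)
# expected: torch.bfloat16 torch.float32 torch.float32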

torchax/torchax/tensor.py

Lines changed: 4 additions & 4 deletions

@@ -59,7 +59,7 @@ def __new__(cls, elem, env):
         cls,
         shape,
         dtype=dtype,
-        device="meta",
+        device="privateuseone:0",
         requires_grad=False,
     )
 
@@ -134,9 +134,9 @@ def dtype(self):
   def dim(self):
     return self.ndim
 
-  @property
-  def device(self):
-    return torch.device("jax:0")
+  # @property
+  # def device(self):
+  #   return torch.device("jax:0")
 
   @property
   def jax_device(self):
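
The two hunks are connected: _make_wrapper_subclass records the device as the subclass's tensor metadata, so once that metadata says "privateuseone:0" the hand-written device property becomes redundant. A minimal sketch of the pattern (not torchax's actual class, and it assumes the privateuse1 setup from __init__.py above):

import torch

class WrapperTensor(torch.Tensor):
    # Stand-in for torchax's tensor wrapper, kept only to show where the
    # reported device comes from.
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.shape,
            dtype=elem.dtype,
            device="privateuseone:0",  # was "meta" before this commit
            requires_grad=False,
        )

    def __init__(self, elem):
        self._elem = elem  # the real data lives elsewhere (a jax.Array in torchax)

t = WrapperTensor(torch.randn(2, 2))
print(t.device)  # privateuseone:0, with no @property override needed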
