
Commit fdf0387

Add autocast feature as torchax.amp.autocast.
This PR implements the three autocast policies that we use and wires them into the Environment. Wiring it through the torch infrastructure so that torch.autocast also works is WIP in #9361.
1 parent 55a7540 commit fdf0387
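
For context, a minimal usage sketch mirroring the included test (it assumes the default torchax environment and the 'jax' device that torchax registers):

import torch
import torchax

env = torchax.default_env()
with env:
  a = torch.randn(2, 2, device='jax')
  b = torch.randn(2, 2, device='jax')
  # Ops registered under LOWER_PRECISION_FP (e.g. matmul) have their
  # floating-point inputs cast to the requested dtype inside the context.
  with torchax.amp.autocast('jax', dtype=torch.bfloat16, env=env):
    c = a @ b
  assert c.dtype == torch.bfloat16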

File tree: 3 files changed, +248 −0 lines changed


torchax/test/test_amp.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
import unittest
import jax
import jax.numpy as jnp
import torchax
from torchax import interop
import torch


class AutocastTest(unittest.TestCase):

  def setUp(self):
    self.env = torchax.default_env()

  def test_auto_cast_ir(self):
    with self.env:
      with torchax.amp.autocast('jax', dtype=torch.bfloat16, env=self.env):
        a = jax.ShapeDtypeStruct((2, 2), jnp.float32)
        b = jax.ShapeDtypeStruct((2, 2), jnp.float32)
        ir_text = jax.jit(interop.jax_view(torch.matmul)).lower(a, b).as_text()
      self.assertIn('tensor<2x2xbf16>', ir_text)

  def test_auto_cast_matmul(self):
    with self.env:
      a = torch.randn(2, 2, device='jax')
      b = torch.randn(2, 2, device='jax')
      with torchax.amp.autocast('jax', dtype=torch.bfloat16, env=self.env):
        c = a @ b

      self.assertEqual(c.dtype, torch.bfloat16)

      with torch.autocast('cpu', dtype=torch.bfloat16):
        c_cpu = a.cpu() @ b.cpu()

      self.assertTrue(torch.allclose(c.cpu(), c_cpu))


if __name__ == '__main__':
  unittest.main()

torchax/torchax/amp.py

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
import contextlib
import enum
import torch
from torch.utils import _pytree as pytree

# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
#                           // running the op. Currently, lower_precision_fp is
#                           // fp16 for AutocastCUDA, and is defined by user
#                           // (default bf16) for AutocastCPU or other device.
#   fp32, // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype, // Treats functions (like softmax) that
#                       // 1. we'd like to run in fp32 and
#                       // 2. have a std::optional<ScalarType> arg that controls
#                       //    the output type.
#                       // fp32_set_opt_dtype wrappers' policy is: if the output
#                       // type is already set, don't touch it, otherwise, set
#                       // it to at::kFloat.
#   fp32_append_dtype, // Treats functions (like norm) that
#                      // 1. we'd like to run in fp32 and
#                      // 2. have some overloads that accept an output type and
#                      //    other overloads that don't.
#                      // fp32_append_dtype wrappers wrap the overloads that don't
#                      // have an output dtype.
#                      // The wrapper policy is: append at::kFloat to the args,
#                      // and redispatch to the type-aware overload.
#   promote, // Run in the widest dtype among several args.
# };
class CastPolicy(enum.Enum):
  LOWER_PRECISION_FP = 0
  FP32 = 1
  FP32_SET_OPT_DTYPE = 2
  FP32_APPEND_DTYPE = 3
  PROMOTE = 4


def execute_policy(policy, args, kwargs, target_lower_fp):

  def is_float(a):
    return isinstance(a, torch.Tensor) and a.is_floating_point()

  match policy:
    case CastPolicy.LOWER_PRECISION_FP:
      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
                                  (args, kwargs))
    case CastPolicy.FP32:
      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
                                  (args, kwargs))
    case CastPolicy.PROMOTE:
      dtypes = set(a.dtype for a in args)
      widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
                                  (args, kwargs))
    case _:
      raise AssertionError(f'Policy {policy} not implemented yet.')


@contextlib.contextmanager
def autocast(device, dtype=torch.bfloat16, env=None):
  del device
  if env is None:
    import torchax
    env = torchax.default_env()
  env.autocast_dtype, old = dtype, env.autocast_dtype
  yield
  env.autocast_dtype = old


# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
autocast_policy = {
    torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,

    # fp32 cast policy
    torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
    torch.ops.aten.polar.default: CastPolicy.FP32,
    torch.ops.aten.prod.default: CastPolicy.FP32,
    torch.ops.aten.prod.dim_int: CastPolicy.FP32,
    torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
    torch.ops.aten.quantile.default: CastPolicy.FP32,
    torch.ops.aten.quantile.scalar: CastPolicy.FP32,
    torch.ops.aten.nanquantile.default: CastPolicy.FP32,
    torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
    torch.ops.aten.stft.default: CastPolicy.FP32,
    torch.ops.aten.stft.center: CastPolicy.FP32,
    torch.ops.aten.cdist.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
    torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
    torch.ops.aten.trace.default: CastPolicy.FP32,
    torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
    torch.ops.aten.cholesky.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
    torch.ops.aten.inverse.default: CastPolicy.FP32,
    torch.ops.aten.lu_solve.default: CastPolicy.FP32,
    torch.ops.aten.orgqr.default: CastPolicy.FP32,
    torch.ops.aten.ormqr.default: CastPolicy.FP32,
    torch.ops.aten.pinverse.default: CastPolicy.FP32,
    torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
    torch.ops.aten.mse_loss.default: CastPolicy.FP32,
    torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
    torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
    torch.ops.aten.l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.huber_loss.default: CastPolicy.FP32,
    torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
    torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
    torch.ops.aten.kl_div.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
    torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
    torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
    torch.ops.aten.geqrf.default: CastPolicy.FP32,
    torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
    torch.ops.aten.qr.default: CastPolicy.FP32,
    torch.ops.aten.svd.default: CastPolicy.FP32,
    torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
    torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,

    # promote
    torch.ops.aten.stack.default: CastPolicy.PROMOTE,
    torch.ops.aten.cat.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
}
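
Outside the Environment, execute_policy can also be exercised directly on plain CPU tensors; a minimal sketch (assuming a PyTorch version whose pytree.tree_map_only accepts a predicate, which the code above already relies on):

import torch
from torchax import amp

a = torch.randn(2, 2)                       # float32
b = torch.randn(2, 2, dtype=torch.float16)  # float16

# LOWER_PRECISION_FP: every floating-point tensor is cast to the target dtype.
args, kwargs = amp.execute_policy(
    amp.CastPolicy.LOWER_PRECISION_FP, (a, b), {}, torch.bfloat16)
print(args[0].dtype, args[1].dtype)  # torch.bfloat16 torch.bfloat16

# PROMOTE: inputs are cast to the widest dtype among the args
# (float32 here, since its itemsize exceeds float16's).
args, kwargs = amp.execute_policy(
    amp.CastPolicy.PROMOTE, (a, b), {}, torch.bfloat16)
print(args[0].dtype, args[1].dtype)  # torch.float32 torch.float32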

torchax/torchax/tensor.py

Lines changed: 9 additions & 0 deletions
@@ -16,6 +16,7 @@
 from torchax.view import View, NarrowInfo
 from torchax import config
 from torchax.ops import mappings, ops_registry
+from torchax import amp

 logger = logging.getLogger(__name__)

@@ -323,6 +324,7 @@ def __init__(self, configuration=None):
     self.enabled = False
     self._jax_devices = set(["jax", "jax_cpu", "xla"])
     self.prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
+    self.autocast_dtype = None

   def manual_seed(self, key):
     self.prng_key = jax.random.key(key)

@@ -512,6 +514,13 @@ def is_not_torchax_tensor(x):
       if not op.is_view_op:
         args, kwargs = self.v2t_iso((args, kwargs))

+      with self:
+        if self.autocast_dtype is not None:
+          autocast_policy = amp.autocast_policy.get(func)
+          if autocast_policy is not None:
+            args, kwargs = amp.execute_policy(
+                autocast_policy, args, kwargs, self.autocast_dtype)
+
       if op.is_jax_function:
         args, kwargs = self.t2j_iso((args, kwargs))
     except AssertionError:
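
In isolation, the hook added to dispatch above amounts to a policy lookup keyed by the ATen op followed by execute_policy; a standalone sketch of that flow, with a hypothetical maybe_autocast helper standing in for the real dispatcher:

import torch
from torchax import amp

def maybe_autocast(func, args, kwargs, autocast_dtype):
  # Mirror of the Environment hook: cast only when autocast is active
  # and the op has a registered policy.
  if autocast_dtype is None:
    return args, kwargs
  policy = amp.autocast_policy.get(func)
  if policy is None:
    return args, kwargs
  return amp.execute_policy(policy, args, kwargs, autocast_dtype)

# matmul is registered under LOWER_PRECISION_FP, so its inputs get downcast.
a, b = torch.randn(2, 2), torch.randn(2, 2)
args, kwargs = maybe_autocast(
    torch.ops.aten.matmul.default, (a, b), {}, torch.bfloat16)
print(args[0].dtype)  # torch.bfloat16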
