diff --git a/README.md b/README.md
index 92a2e3a..1b6ebf3 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ def f(x):
     x.mul_(2)
     return x
 
-f(torch.Tensor([3])) # => torch.Tensor([8])
+f(torch.tensor([3])) # => torch.Tensor([8])
 
 jax_f = t2j(f)
 jax_f(jnp.array([3])) # => jnp.array([8])
@@ -69,17 +69,47 @@ torch2jax is an implementation of the PyTorch standard library written in JAX. I
 
 Adding new PyTorch operations is straightforward. Check the source for functions decorated with `@implements` to get started.
 
-### My PyTorch model includes dropout (or some other random operation), and does not work in training mode. Why?
+### My PyTorch model includes dropout or some other random operation. How does this work with torch2jax?
 
-JAX mandates [deterministic randomness](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html), while PyTorch does not. This leads to some API friction. torch2jax does not currently offer a means to bridge this gap. I have an idea for how to accomplish it. If this is important to you, please open an issue.
+Pass a `jax.random.PRNGKey` to the converted function:
 
-In the meantime, make sure to call `.eval()` on your `torch.nn.Module` before conversion.
+```python
+t2j(lambda: torch.randn(3))(rng=jax.random.PRNGKey(123))
+# => [-0.56996626, -0.6440589 , 0.28660855]
+
+t2j(lambda: torch.randn(3))(rng=jax.random.PRNGKey(456))
+# => [-1.3227656, -1.4896724, -2.5057693]
+```
+
+After conversion, random state will be handled entirely in JAX. `torch.manual_seed` and its ilk will have no effect on the converted function.
+
+If you only care about running a model and not training it, you can call `.eval()` on it to avoid the randomness issue altogether, at least for most common random operations like dropout:
+
+```python
+rn18 = torchvision.models.resnet18().eval()
+t2j(rn18)(t2j(torch.randn(1, 3, 224, 224))) # Look ma, no `rng` kwarg!
+```
 
-### My PyTorch model includes batch norm (or some other `torch.nn.Module` utilizing buffers), and does not work in training mode. What can I do?
+> [!NOTE]
+> Non-deterministic behavior is, well, non-deterministic. You will not see the same results with the same random seed when switching between PyTorch and JAX. However, the sampling process _will_ be equivalent.
 
-Similar to the randomness story, PyTorch and JAX have different approaches to maintaining state. Operations like batch norm require maintaining running statistics. In PyTorch, this is accomplished via [buffers](https://stackoverflow.com/questions/57540745/what-is-the-difference-between-register-parameter-and-register-buffer-in-pytorch/57546078#57546078).
+### My PyTorch model includes batch norm or some other `torch.nn.Module` that mutates buffers. How does this work with torch2jax?
 
-torch2jax supports running batch norm models in `eval()`-mode. Just don't forget that you should avoid taking gradients w.r.t. buffers. For example,
+Some PyTorch modules like `torch.nn.BatchNorm1d` mutate internal state in the form of [buffers](https://stackoverflow.com/questions/57540745/what-is-the-difference-between-register-parameter-and-register-buffer-in-pytorch/57546078#57546078).
+
+torch2jax supports this with the optional `return_state_dict` argument:
+
+```python
+rn18 = torchvision.models.resnet18()
+batch = torch.randn(1, 3, 224, 224)
+
+before_state_dict = {k: t2j(v) for k, v in rn18.state_dict().items()}
+out, after_state_dict = t2j(rn18)(t2j(batch), state_dict=before_state_dict, return_state_dict=True)
+```
+
+As with randomness, if you only care about running a model and not training it, you can call `.eval()` on it to avoid buffer issues altogether in most cases.
+
+Also, don't forget to avoid taking gradients w.r.t. buffers. For example,
 
 ```python
 rn18 = torchvision.models.resnet18().eval()
@@ -95,13 +125,11 @@ jax_rn18 = t2j(rn18)
 grad(lambda params, x: loss(jax_rn18(x, state_dict={**params, **buffers})))(parameters, t2j(batch))
 ```
 
-I have an idea for how to implement buffers, including in training mode. If this is important to you, please open an issue.
-
 ### I'm seeing slightly different numerical results between PyTorch and JAX. Is it a bug?
 
 Floating point arithmetic is hard. There are a number of sources of divergence preventing bit-for-bit equivalence:
 
-1. torch2jax guarantees equivalence with PyTorch standard library functions in the mathematical sense, but not necessarily in their operational execution. This can lead to slight differences in results.
+1. torch2jax guarantees equivalence with PyTorch standard library functions in the mathematical sense, but not necessarily in their operational execution. This can lead to slight differences in results. For example, the multi-head attention implementations calculate the same mathematical function, but may vary in execution details such as the order of operations, the use of fused kernels, and so forth.
 2. The JAX/XLA and PyTorch compilers apply different optimizations and should be expected to rewrite computation graphs in exciting and unpredictable ways, potentially invoking different CUDA kernels.
 3. CUDA kernels can be non-deterministic, for example as a result of floating point addition being non-associative.
 
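The divergence sources listed in the final hunk mean that PyTorch and torch2jax outputs should be compared up to a tolerance rather than bit-for-bit. A minimal sketch of such a check, reusing the README's own eval-mode `resnet18` pattern; the `rtol`/`atol` values are illustrative assumptions, not tolerances guaranteed by torch2jax:

```python
import numpy as np
import torch
import torchvision
from torch2jax import t2j

rn18 = torchvision.models.resnet18().eval()
batch = torch.randn(1, 3, 224, 224)

# Reference output computed by PyTorch.
torch_out = rn18(batch).detach().numpy()

# The same model and input run through the torch2jax conversion.
jax_out = np.asarray(t2j(rn18)(t2j(batch)))

# Expect agreement up to floating point noise, not exact equality.
np.testing.assert_allclose(torch_out, jax_out, rtol=1e-4, atol=1e-4)
```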